Răsfoiți Sursa

add datetime var comparison operations (#4406)

* add datetime var operations

* add future annotations

* add LiteralDatetimeVar

* remove methods that don't apply

* fix serialization

* add unit and integrations test

* oops, forgot to commit that important change
Thomas Brandého 5 luni în urmă
părinte
comite
76ce112002

+ 1 - 0
reflex/vars/__init__.py

@@ -9,6 +9,7 @@ from .base import get_unique_variable_name as get_unique_variable_name
 from .base import get_uuid_string_var as get_uuid_string_var
 from .base import var_operation as var_operation
 from .base import var_operation_return as var_operation_return
+from .datetime import DateTimeVar as DateTimeVar
 from .function import FunctionStringVar as FunctionStringVar
 from .function import FunctionVar as FunctionVar
 from .function import VarOperationCall as VarOperationCall

+ 222 - 0
reflex/vars/datetime.py

@@ -0,0 +1,222 @@
+"""Immutable datetime and date vars."""
+
+from __future__ import annotations
+
+import dataclasses
+import sys
+from datetime import date, datetime
+from typing import Any, NoReturn, TypeVar, Union, overload
+
+from reflex.utils.exceptions import VarTypeError
+from reflex.vars.number import BooleanVar
+
+from .base import (
+    CustomVarOperationReturn,
+    LiteralVar,
+    Var,
+    VarData,
+    var_operation,
+    var_operation_return,
+)
+
+DATETIME_T = TypeVar("DATETIME_T", datetime, date)
+
+datetime_types = Union[datetime, date]
+
+
+def raise_var_type_error():
+    """Raise a VarTypeError.
+
+    Raises:
+        VarTypeError: Cannot compare a datetime object with a non-datetime object.
+    """
+    raise VarTypeError("Cannot compare a datetime object with a non-datetime object.")
+
+
+class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)):
+    """A variable that holds a datetime or date object."""
+
+    @overload
+    def __lt__(self, other: datetime_types) -> BooleanVar: ...
+
+    @overload
+    def __lt__(self, other: NoReturn) -> NoReturn: ...
+
+    def __lt__(self, other: Any):
+        """Less than comparison.
+
+        Args:
+            other: The other datetime to compare.
+
+        Returns:
+            The result of the comparison.
+        """
+        if not isinstance(other, DATETIME_TYPES):
+            raise_var_type_error()
+        return date_lt_operation(self, other)
+
+    @overload
+    def __le__(self, other: datetime_types) -> BooleanVar: ...
+
+    @overload
+    def __le__(self, other: NoReturn) -> NoReturn: ...
+
+    def __le__(self, other: Any):
+        """Less than or equal comparison.
+
+        Args:
+            other: The other datetime to compare.
+
+        Returns:
+            The result of the comparison.
+        """
+        if not isinstance(other, DATETIME_TYPES):
+            raise_var_type_error()
+        return date_le_operation(self, other)
+
+    @overload
+    def __gt__(self, other: datetime_types) -> BooleanVar: ...
+
+    @overload
+    def __gt__(self, other: NoReturn) -> NoReturn: ...
+
+    def __gt__(self, other: Any):
+        """Greater than comparison.
+
+        Args:
+            other: The other datetime to compare.
+
+        Returns:
+            The result of the comparison.
+        """
+        if not isinstance(other, DATETIME_TYPES):
+            raise_var_type_error()
+        return date_gt_operation(self, other)
+
+    @overload
+    def __ge__(self, other: datetime_types) -> BooleanVar: ...
+
+    @overload
+    def __ge__(self, other: NoReturn) -> NoReturn: ...
+
+    def __ge__(self, other: Any):
+        """Greater than or equal comparison.
+
+        Args:
+            other: The other datetime to compare.
+
+        Returns:
+            The result of the comparison.
+        """
+        if not isinstance(other, DATETIME_TYPES):
+            raise_var_type_error()
+        return date_ge_operation(self, other)
+
+
+@var_operation
+def date_gt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
+    """Greater than comparison.
+
+    Args:
+        lhs: The left-hand side of the operation.
+        rhs: The right-hand side of the operation.
+
+    Returns:
+        The result of the operation.
+    """
+    return date_compare_operation(rhs, lhs, strict=True)
+
+
+@var_operation
+def date_lt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
+    """Less than comparison.
+
+    Args:
+        lhs: The left-hand side of the operation.
+        rhs: The right-hand side of the operation.
+
+    Returns:
+        The result of the operation.
+    """
+    return date_compare_operation(lhs, rhs, strict=True)
+
+
+@var_operation
+def date_le_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
+    """Less than or equal comparison.
+
+    Args:
+        lhs: The left-hand side of the operation.
+        rhs: The right-hand side of the operation.
+
+    Returns:
+        The result of the operation.
+    """
+    return date_compare_operation(lhs, rhs)
+
+
+@var_operation
+def date_ge_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
+    """Greater than or equal comparison.
+
+    Args:
+        lhs: The left-hand side of the operation.
+        rhs: The right-hand side of the operation.
+
+    Returns:
+        The result of the operation.
+    """
+    return date_compare_operation(rhs, lhs)
+
+
+def date_compare_operation(
+    lhs: DateTimeVar[DATETIME_T] | Any,
+    rhs: DateTimeVar[DATETIME_T] | Any,
+    strict: bool = False,
+) -> CustomVarOperationReturn:
+    """Check if the value is less than the other value.
+
+    Args:
+        lhs: The left-hand side of the operation.
+        rhs: The right-hand side of the operation.
+        strict: Whether to use strict comparison.
+
+    Returns:
+        The result of the operation.
+    """
+    return var_operation_return(
+        f"({lhs} { '<' if strict else '<='} {rhs})",
+        bool,
+    )
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class LiteralDatetimeVar(LiteralVar, DateTimeVar):
+    """Base class for immutable datetime and date vars."""
+
+    _var_value: datetime | date = dataclasses.field(default=datetime.now())
+
+    @classmethod
+    def create(cls, value: datetime | date, _var_data: VarData | None = None):
+        """Create a new instance of the class.
+
+        Args:
+            value: The value to set.
+
+        Returns:
+            LiteralDatetimeVar: The new instance of the class.
+        """
+        js_expr = f'"{str(value)}"'
+        return cls(
+            _js_expr=js_expr,
+            _var_type=type(value),
+            _var_value=value,
+            _var_data=_var_data,
+        )
+
+
+DATETIME_TYPES = (datetime, date, DateTimeVar)

+ 87 - 0
tests/integration/tests_playwright/test_datetime_operations.py

@@ -0,0 +1,87 @@
+from typing import Generator
+
+import pytest
+from playwright.sync_api import Page, expect
+
+from reflex.testing import AppHarness
+
+
+def DatetimeOperationsApp():
+    from datetime import datetime
+
+    import reflex as rx
+
+    class DtOperationsState(rx.State):
+        date1: datetime = datetime(2021, 1, 1)
+        date2: datetime = datetime(2031, 1, 1)
+        date3: datetime = datetime(2021, 1, 1)
+
+    app = rx.App(state=DtOperationsState)
+
+    @app.add_page
+    def index():
+        return rx.vstack(
+            rx.text(DtOperationsState.date1, id="date1"),
+            rx.text(DtOperationsState.date2, id="date2"),
+            rx.text(DtOperationsState.date3, id="date3"),
+            rx.text("Operations between date1 and date2"),
+            rx.text(DtOperationsState.date1 == DtOperationsState.date2, id="1_eq_2"),
+            rx.text(DtOperationsState.date1 != DtOperationsState.date2, id="1_neq_2"),
+            rx.text(DtOperationsState.date1 < DtOperationsState.date2, id="1_lt_2"),
+            rx.text(DtOperationsState.date1 <= DtOperationsState.date2, id="1_le_2"),
+            rx.text(DtOperationsState.date1 > DtOperationsState.date2, id="1_gt_2"),
+            rx.text(DtOperationsState.date1 >= DtOperationsState.date2, id="1_ge_2"),
+            rx.text("Operations between date1 and date3"),
+            rx.text(DtOperationsState.date1 == DtOperationsState.date3, id="1_eq_3"),
+            rx.text(DtOperationsState.date1 != DtOperationsState.date3, id="1_neq_3"),
+            rx.text(DtOperationsState.date1 < DtOperationsState.date3, id="1_lt_3"),
+            rx.text(DtOperationsState.date1 <= DtOperationsState.date3, id="1_le_3"),
+            rx.text(DtOperationsState.date1 > DtOperationsState.date3, id="1_gt_3"),
+            rx.text(DtOperationsState.date1 >= DtOperationsState.date3, id="1_ge_3"),
+        )
+
+
+@pytest.fixture()
+def datetime_operations_app(tmp_path_factory) -> Generator[AppHarness, None, None]:
+    """Start Table app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+
+    """
+    with AppHarness.create(
+        root=tmp_path_factory.mktemp("datetime_operations_app"),
+        app_source=DatetimeOperationsApp,  # type: ignore
+    ) as harness:
+        assert harness.app_instance is not None, "app is not running"
+        yield harness
+
+
+def test_datetime_operations(datetime_operations_app: AppHarness, page: Page):
+    assert datetime_operations_app.frontend_url is not None
+
+    page.goto(datetime_operations_app.frontend_url)
+    expect(page).to_have_url(datetime_operations_app.frontend_url + "/")
+    # Check the actual values
+    expect(page.locator("id=date1")).to_have_text("2021-01-01 00:00:00")
+    expect(page.locator("id=date2")).to_have_text("2031-01-01 00:00:00")
+    expect(page.locator("id=date3")).to_have_text("2021-01-01 00:00:00")
+
+    # Check the operations between date1 and date2
+    expect(page.locator("id=1_eq_2")).to_have_text("false")
+    expect(page.locator("id=1_neq_2")).to_have_text("true")
+    expect(page.locator("id=1_lt_2")).to_have_text("true")
+    expect(page.locator("id=1_le_2")).to_have_text("true")
+    expect(page.locator("id=1_gt_2")).to_have_text("false")
+    expect(page.locator("id=1_ge_2")).to_have_text("false")
+
+    # Check the operations between date1 and date3
+    expect(page.locator("id=1_eq_3")).to_have_text("true")
+    expect(page.locator("id=1_neq_3")).to_have_text("false")
+    expect(page.locator("id=1_lt_3")).to_have_text("false")
+    expect(page.locator("id=1_le_3")).to_have_text("true")
+    expect(page.locator("id=1_gt_3")).to_have_text("false")
+    expect(page.locator("id=1_ge_3")).to_have_text("true")

+ 1 - 0
tests/units/utils/test_serializers.py

@@ -222,6 +222,7 @@ def test_serialize(value: Any, expected: str):
             '"2021-01-01 01:01:01.000001"',
             True,
         ),
+        (datetime.date(2021, 1, 1), '"2021-01-01"', True),
         (Color(color="slate", shade=1), '"var(--slate-1)"', True),
         (BaseSubclass, '"BaseSubclass"', True),
         (Path("."), '"."', True),