Prechádzať zdrojové kódy

improve object var symantics (#4290)

* improve object var symantics

* add case for serializers

* check against serializer with to = dict

* add tests

* fix typing issues

* remove default value

* older version of python doesn't have assert type

* add base to rx field cases

* get it from typing_extension
Khaleel Al-Adhami 6 mesiacov pred
rodič
commit
b5d1e03de1

+ 25 - 4
reflex/utils/serializers.py

@@ -78,7 +78,7 @@ def serializer(
         )
 
     # Apply type transformation if requested
-    if to is not None:
+    if to is not None or ((to := type_hints.get("return")) is not None):
         SERIALIZER_TYPES[type_] = to
         get_serializer_type.cache_clear()
 
@@ -189,16 +189,37 @@ def get_serializer_type(type_: Type) -> Optional[Type]:
     return None
 
 
-def has_serializer(type_: Type) -> bool:
+def has_serializer(type_: Type, into_type: Type | None = None) -> bool:
     """Check if there is a serializer for the type.
 
     Args:
         type_: The type to check.
+        into_type: The type to serialize into.
 
     Returns:
         Whether there is a serializer for the type.
     """
-    return get_serializer(type_) is not None
+    serializer_for_type = get_serializer(type_)
+    return serializer_for_type is not None and (
+        into_type is None or get_serializer_type(type_) == into_type
+    )
+
+
+def can_serialize(type_: Type, into_type: Type | None = None) -> bool:
+    """Check if there is a serializer for the type.
+
+    Args:
+        type_: The type to check.
+        into_type: The type to serialize into.
+
+    Returns:
+        Whether there is a serializer for the type.
+    """
+    return has_serializer(type_, into_type) or (
+        isinstance(type_, type)
+        and dataclasses.is_dataclass(type_)
+        and (into_type is None or into_type is dict)
+    )
 
 
 @serializer(to=str)
@@ -214,7 +235,7 @@ def serialize_type(value: type) -> str:
     return value.__name__
 
 
-@serializer
+@serializer(to=dict)
 def serialize_base(value: Base) -> dict:
     """Serialize a Base instance.
 

+ 48 - 34
reflex/vars/base.py

@@ -75,7 +75,6 @@ from reflex.utils.types import (
 if TYPE_CHECKING:
     from reflex.state import BaseState
 
-    from .function import FunctionVar
     from .number import (
         BooleanVar,
         NumberVar,
@@ -279,6 +278,24 @@ def _decode_var_immutable(value: str) -> tuple[VarData | None, str]:
     return VarData.merge(*var_datas) if var_datas else None, value
 
 
+def can_use_in_object_var(cls: GenericType) -> bool:
+    """Check if the class can be used in an ObjectVar.
+
+    Args:
+        cls: The class to check.
+
+    Returns:
+        Whether the class can be used in an ObjectVar.
+    """
+    if types.is_union(cls):
+        return all(can_use_in_object_var(t) for t in types.get_args(cls))
+    return (
+        inspect.isclass(cls)
+        and not issubclass(cls, Var)
+        and serializers.can_serialize(cls, dict)
+    )
+
+
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
@@ -565,36 +582,33 @@ class Var(Generic[VAR_TYPE]):
         # Encode the _var_data into the formatted output for tracking purposes.
         return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._js_expr}"
 
-    @overload
-    def to(self, output: Type[StringVar]) -> StringVar: ...
-
     @overload
     def to(self, output: Type[str]) -> StringVar: ...
 
     @overload
-    def to(self, output: Type[BooleanVar]) -> BooleanVar: ...
+    def to(self, output: Type[bool]) -> BooleanVar: ...
 
     @overload
-    def to(
-        self, output: Type[NumberVar], var_type: type[int] | type[float] = float
-    ) -> NumberVar: ...
+    def to(self, output: type[int] | type[float]) -> NumberVar: ...
 
     @overload
     def to(
         self,
-        output: Type[ArrayVar],
-        var_type: type[list] | type[tuple] | type[set] = list,
+        output: type[list] | type[tuple] | type[set],
     ) -> ArrayVar: ...
 
     @overload
     def to(
-        self, output: Type[ObjectVar], var_type: types.GenericType = dict
-    ) -> ObjectVar: ...
+        self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE]
+    ) -> ObjectVar[VAR_INSIDE]: ...
 
     @overload
     def to(
-        self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
-    ) -> FunctionVar: ...
+        self, output: Type[ObjectVar], var_type: None = None
+    ) -> ObjectVar[VAR_TYPE]: ...
+
+    @overload
+    def to(self, output: VAR_SUBCLASS, var_type: None = None) -> VAR_SUBCLASS: ...
 
     @overload
     def to(
@@ -630,21 +644,19 @@ class Var(Generic[VAR_TYPE]):
             return get_to_operation(NoneVar).create(self)  # type: ignore
 
         # Handle fixed_output_type being Base or a dataclass.
-        try:
-            if issubclass(fixed_output_type, Base):
-                return self.to(ObjectVar, output)
-        except TypeError:
-            pass
-        if dataclasses.is_dataclass(fixed_output_type) and not issubclass(
-            fixed_output_type, Var
-        ):
+        if can_use_in_object_var(fixed_output_type):
             return self.to(ObjectVar, output)
 
         if inspect.isclass(output):
             for var_subclass in _var_subclasses[::-1]:
                 if issubclass(output, var_subclass.var_subclass):
+                    current_var_type = self._var_type
+                    if current_var_type is Any:
+                        new_var_type = var_type
+                    else:
+                        new_var_type = var_type or current_var_type
                     to_operation_return = var_subclass.to_var_subclass.create(
-                        value=self, _var_type=var_type
+                        value=self, _var_type=new_var_type
                     )
                     return to_operation_return  # type: ignore
 
@@ -707,11 +719,7 @@ class Var(Generic[VAR_TYPE]):
             ):
                 return self.to(NumberVar, self._var_type)
 
-            if all(
-                inspect.isclass(t)
-                and (issubclass(t, Base) or dataclasses.is_dataclass(t))
-                for t in inner_types
-            ):
+            if can_use_in_object_var(var_type):
                 return self.to(ObjectVar, self._var_type)
 
             return self
@@ -730,13 +738,9 @@ class Var(Generic[VAR_TYPE]):
             if issubclass(fixed_type, var_subclass.python_types):
                 return self.to(var_subclass.var_subclass, self._var_type)
 
-        try:
-            if issubclass(fixed_type, Base):
-                return self.to(ObjectVar, self._var_type)
-        except TypeError:
-            pass
-        if dataclasses.is_dataclass(fixed_type):
+        if can_use_in_object_var(fixed_type):
             return self.to(ObjectVar, self._var_type)
+
         return self
 
     def get_default_value(self) -> Any:
@@ -1181,6 +1185,9 @@ class Var(Generic[VAR_TYPE]):
 
 OUTPUT = TypeVar("OUTPUT", bound=Var)
 
+VAR_SUBCLASS = TypeVar("VAR_SUBCLASS", bound=Var)
+VAR_INSIDE = TypeVar("VAR_INSIDE")
+
 
 class ToOperation:
     """A var operation that converts a var to another type."""
@@ -2888,6 +2895,8 @@ def dispatch(
 
 V = TypeVar("V")
 
+BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
+
 
 class Field(Generic[T]):
     """Shadow class for Var to allow for type hinting in the IDE."""
@@ -2924,6 +2933,11 @@ class Field(Generic[T]):
         self: Field[Dict[str, V]], instance: None, owner
     ) -> ObjectVar[Dict[str, V]]: ...
 
+    @overload
+    def __get__(
+        self: Field[BASE_TYPE], instance: None, owner
+    ) -> ObjectVar[BASE_TYPE]: ...
+
     @overload
     def __get__(self, instance: None, owner) -> Var[T]: ...
 

+ 3 - 1
reflex/vars/number.py

@@ -1116,7 +1116,9 @@ U = TypeVar("U")
 
 
 @var_operation
-def ternary_operation(condition: BooleanVar, if_true: Var[T], if_false: Var[U]):
+def ternary_operation(
+    condition: BooleanVar, if_true: Var[T], if_false: Var[U]
+) -> CustomVarOperationReturn[Union[T, U]]:
     """Create a ternary operation.
 
     Args:

+ 30 - 24
reflex/vars/object.py

@@ -36,7 +36,7 @@ from .base import (
 from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
 from .sequence import ArrayVar, StringVar
 
-OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
+OBJECT_TYPE = TypeVar("OBJECT_TYPE")
 
 KEY_TYPE = TypeVar("KEY_TYPE")
 VALUE_TYPE = TypeVar("VALUE_TYPE")
@@ -59,7 +59,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 
     @overload
     def _value_type(
-        self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
+        self: ObjectVar[Dict[Any, VALUE_TYPE]],
     ) -> Type[VALUE_TYPE]: ...
 
     @overload
@@ -87,7 +87,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 
     @overload
     def values(
-        self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
+        self: ObjectVar[Dict[Any, VALUE_TYPE]],
     ) -> ArrayVar[List[VALUE_TYPE]]: ...
 
     @overload
@@ -103,7 +103,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 
     @overload
     def entries(
-        self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
+        self: ObjectVar[Dict[Any, VALUE_TYPE]],
     ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
 
     @overload
@@ -133,47 +133,47 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
     # NoReturn is used here to catch when key value is Any
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
+        self: ObjectVar[Dict[Any, NoReturn]],
         key: Var | Any,
     ) -> Var: ...
 
     @overload
     def __getitem__(
         self: (
-            ObjectVar[Dict[KEY_TYPE, int]]
-            | ObjectVar[Dict[KEY_TYPE, float]]
-            | ObjectVar[Dict[KEY_TYPE, int | float]]
+            ObjectVar[Dict[Any, int]]
+            | ObjectVar[Dict[Any, float]]
+            | ObjectVar[Dict[Any, int | float]]
         ),
         key: Var | Any,
     ) -> NumberVar: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[KEY_TYPE, str]],
+        self: ObjectVar[Dict[Any, str]],
         key: Var | Any,
     ) -> StringVar: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
         key: Var | Any,
     ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
         key: Var | Any,
     ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
+        self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
         key: Var | Any,
     ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
+        self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
         key: Var | Any,
     ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
 
@@ -195,50 +195,56 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
     # NoReturn is used here to catch when key value is Any
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
+        self: ObjectVar[Dict[Any, NoReturn]],
         name: str,
     ) -> Var: ...
 
     @overload
     def __getattr__(
         self: (
-            ObjectVar[Dict[KEY_TYPE, int]]
-            | ObjectVar[Dict[KEY_TYPE, float]]
-            | ObjectVar[Dict[KEY_TYPE, int | float]]
+            ObjectVar[Dict[Any, int]]
+            | ObjectVar[Dict[Any, float]]
+            | ObjectVar[Dict[Any, int | float]]
         ),
         name: str,
     ) -> NumberVar: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[KEY_TYPE, str]],
+        self: ObjectVar[Dict[Any, str]],
         name: str,
     ) -> StringVar: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
         name: str,
     ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
         name: str,
     ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
+        self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
         name: str,
     ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
+        self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
         name: str,
     ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
 
+    @overload
+    def __getattr__(
+        self: ObjectVar,
+        name: str,
+    ) -> ObjectItemOperation: ...
+
     def __getattr__(self, name) -> Var:
         """Get an attribute of the var.
 
@@ -377,8 +383,8 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
     @classmethod
     def create(
         cls,
-        _var_value: OBJECT_TYPE,
-        _var_type: GenericType | None = None,
+        _var_value: dict,
+        _var_type: Type[OBJECT_TYPE] | None = None,
         _var_data: VarData | None = None,
     ) -> LiteralObjectVar[OBJECT_TYPE]:
         """Create the literal object var.

+ 7 - 7
reflex/vars/sequence.py

@@ -853,31 +853,31 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
     @overload
     def __getitem__(
         self: (
-            ArrayVar[Tuple[OTHER_TUPLE, int]]
-            | ArrayVar[Tuple[OTHER_TUPLE, float]]
-            | ArrayVar[Tuple[OTHER_TUPLE, int | float]]
+            ArrayVar[Tuple[Any, int]]
+            | ArrayVar[Tuple[Any, float]]
+            | ArrayVar[Tuple[Any, int | float]]
         ),
         i: Literal[1, -1],
     ) -> NumberVar: ...
 
     @overload
     def __getitem__(
-        self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2]
+        self: ArrayVar[Tuple[str, Any]], i: Literal[0, -2]
     ) -> StringVar: ...
 
     @overload
     def __getitem__(
-        self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1]
+        self: ArrayVar[Tuple[Any, str]], i: Literal[1, -1]
     ) -> StringVar: ...
 
     @overload
     def __getitem__(
-        self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2]
+        self: ArrayVar[Tuple[bool, Any]], i: Literal[0, -2]
     ) -> BooleanVar: ...
 
     @overload
     def __getitem__(
-        self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1]
+        self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
     ) -> BooleanVar: ...
 
     @overload

+ 7 - 7
tests/units/test_state.py

@@ -45,7 +45,7 @@ from reflex.testing import chdir
 from reflex.utils import format, prerequisites, types
 from reflex.utils.exceptions import SetUndefinedStateVarError
 from reflex.utils.format import json_dumps
-from reflex.vars.base import ComputedVar, Var
+from reflex.vars.base import Var, computed_var
 from tests.units.states.mutation import MutableSQLAModel, MutableTestState
 
 from .states import GenState
@@ -109,7 +109,7 @@ class TestState(BaseState):
     _backend: int = 0
     asynctest: int = 0
 
-    @ComputedVar
+    @computed_var
     def sum(self) -> float:
         """Dynamically sum the numbers.
 
@@ -118,7 +118,7 @@ class TestState(BaseState):
         """
         return self.num1 + self.num2
 
-    @ComputedVar
+    @computed_var
     def upper(self) -> str:
         """Uppercase the key.
 
@@ -1124,7 +1124,7 @@ def test_child_state():
         v: int = 2
 
     class ChildState(MainState):
-        @ComputedVar
+        @computed_var
         def rendered_var(self):
             return self.v
 
@@ -1143,7 +1143,7 @@ def test_conditional_computed_vars():
         t1: str = "a"
         t2: str = "b"
 
-        @ComputedVar
+        @computed_var
         def rendered_var(self) -> str:
             if self.flag:
                 return self.t1
@@ -3095,12 +3095,12 @@ def test_potentially_dirty_substates():
     """
 
     class State(RxState):
-        @ComputedVar
+        @computed_var
         def foo(self) -> str:
             return ""
 
     class C1(State):
-        @ComputedVar
+        @computed_var
         def bar(self) -> str:
             return ""
 

+ 102 - 0
tests/units/vars/test_object.py

@@ -0,0 +1,102 @@
+import pytest
+from typing_extensions import assert_type
+
+import reflex as rx
+from reflex.utils.types import GenericType
+from reflex.vars.base import Var
+from reflex.vars.object import LiteralObjectVar, ObjectVar
+
+
+class Bare:
+    """A bare class with a single attribute."""
+
+    quantity: int = 0
+
+
+@rx.serializer
+def serialize_bare(obj: Bare) -> dict:
+    """A serializer for the bare class.
+
+    Args:
+        obj: The object to serialize.
+
+    Returns:
+        A dictionary with the quantity attribute.
+    """
+    return {"quantity": obj.quantity}
+
+
+class Base(rx.Base):
+    """A reflex base class with a single attribute."""
+
+    quantity: int = 0
+
+
+class ObjectState(rx.State):
+    """A reflex state with bare and base objects."""
+
+    bare: rx.Field[Bare] = rx.field(Bare())
+    base: rx.Field[Base] = rx.field(Base())
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_var_create(type_: GenericType) -> None:
+    my_object = type_()
+    var = Var.create(my_object)
+    assert var._var_type is type_
+
+    quantity = var.quantity
+    assert quantity._var_type is int
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_literal_create(type_: GenericType) -> None:
+    my_object = type_()
+    var = LiteralObjectVar.create(my_object)
+    assert var._var_type is type_
+
+    quantity = var.quantity
+    assert quantity._var_type is int
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_guess(type_: GenericType) -> None:
+    my_object = type_()
+    var = Var.create(my_object)
+    var = var.guess_type()
+    assert var._var_type is type_
+
+    quantity = var.quantity
+    assert quantity._var_type is int
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_state(type_: GenericType) -> None:
+    attr_name = type_.__name__.lower()
+    var = getattr(ObjectState, attr_name)
+    assert var._var_type is type_
+
+    quantity = var.quantity
+    assert quantity._var_type is int
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_state_to_operation(type_: GenericType) -> None:
+    attr_name = type_.__name__.lower()
+    original_var = getattr(ObjectState, attr_name)
+
+    var = original_var.to(ObjectVar, type_)
+    assert var._var_type is type_
+
+    var = original_var.to(ObjectVar)
+    assert var._var_type is type_
+
+
+def test_typing() -> None:
+    # Bare
+    var = ObjectState.bare.to(ObjectVar)
+    _ = assert_type(var, ObjectVar[Bare])
+
+    # Base
+    var = ObjectState.base
+    _ = assert_type(var, ObjectVar[Base])