Jelajahi Sumber

improve rx.Field ObjectVar typing for sqlalchemy and dataclasses (#4728)

* improve rx.Field ObjectVar typing for sqlalchemy and dataclasses

* enable parametrized objectvar tests for sqlamodel and dataclass

* improve typing for ObjectVars in ArrayVars

* ruffing

* drop duplicate objectvar import

* remove redundant overload

* allow optional hints in rx.Field annotations to resolve to the correct var type
benedikt-bartscher 3 bulan lalu
induk
melakukan
2b7e4d6b4e
3 mengubah file dengan 111 tambahan dan 15 penghapusan
  1. 24 5
      reflex/vars/base.py
  2. 21 4
      reflex/vars/sequence.py
  3. 66 6
      tests/units/vars/test_object.py

+ 24 - 5
reflex/vars/base.py

@@ -40,6 +40,7 @@ from typing import (
     overload,
     overload,
 )
 )
 
 
+from sqlalchemy.orm import DeclarativeBase
 from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
 from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
 
 
 from reflex import constants
 from reflex import constants
@@ -573,7 +574,7 @@ class Var(Generic[VAR_TYPE]):
 
 
     @overload
     @overload
     @classmethod
     @classmethod
-    def create(  # type: ignore[override]
+    def create(  # pyright: ignore[reportOverlappingOverload]
         cls,
         cls,
         value: bool,
         value: bool,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
@@ -581,7 +582,7 @@ class Var(Generic[VAR_TYPE]):
 
 
     @overload
     @overload
     @classmethod
     @classmethod
-    def create(  # type: ignore[override]
+    def create(
         cls,
         cls,
         value: int,
         value: int,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
@@ -605,7 +606,7 @@ class Var(Generic[VAR_TYPE]):
 
 
     @overload
     @overload
     @classmethod
     @classmethod
-    def create(
+    def create(  # pyright: ignore[reportOverlappingOverload]
         cls,
         cls,
         value: None,
         value: None,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
@@ -3182,10 +3183,16 @@ def dispatch(
 
 
 V = TypeVar("V")
 V = TypeVar("V")
 
 
-BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
+BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None)
+SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None)
+
+if TYPE_CHECKING:
+    from _typeshed import DataclassInstance
+
+    DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None)
 
 
 FIELD_TYPE = TypeVar("FIELD_TYPE")
 FIELD_TYPE = TypeVar("FIELD_TYPE")
-MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
+MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None)
 
 
 
 
 class Field(Generic[FIELD_TYPE]):
 class Field(Generic[FIELD_TYPE]):
@@ -3230,6 +3237,18 @@ class Field(Generic[FIELD_TYPE]):
         self: Field[BASE_TYPE], instance: None, owner: Any
         self: Field[BASE_TYPE], instance: None, owner: Any
     ) -> ObjectVar[BASE_TYPE]: ...
     ) -> ObjectVar[BASE_TYPE]: ...
 
 
+    @overload
+    def __get__(
+        self: Field[SQLA_TYPE], instance: None, owner: Any
+    ) -> ObjectVar[SQLA_TYPE]: ...
+
+    if TYPE_CHECKING:
+
+        @overload
+        def __get__(
+            self: Field[DATACLASS_TYPE], instance: None, owner: Any
+        ) -> ObjectVar[DATACLASS_TYPE]: ...
+
     @overload
     @overload
     def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
     def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
 
 

+ 21 - 4
reflex/vars/sequence.py

@@ -53,8 +53,11 @@ from .number import (
 )
 )
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
+    from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE
+    from .function import FunctionVar
     from .object import ObjectVar
     from .object import ObjectVar
 
 
+
 STRING_TYPE = TypeVar("STRING_TYPE", default=str)
 STRING_TYPE = TypeVar("STRING_TYPE", default=str)
 
 
 
 
@@ -961,6 +964,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
         i: int | NumberVar,
         i: int | NumberVar,
     ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
     ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
 
 
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE],
+        i: int | NumberVar,
+    ) -> ObjectVar[BASE_TYPE]: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE],
+        i: int | NumberVar,
+    ) -> ObjectVar[SQLA_TYPE]: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE],
+        i: int | NumberVar,
+    ) -> ObjectVar[DATACLASS_TYPE]: ...
+
     @overload
     @overload
     def __getitem__(self, i: int | NumberVar) -> Var: ...
     def __getitem__(self, i: int | NumberVar) -> Var: ...
 
 
@@ -1648,10 +1669,6 @@ def repeat_array_operation(
     )
     )
 
 
 
 
-if TYPE_CHECKING:
-    from .function import FunctionVar
-
-
 @var_operation
 @var_operation
 def map_array_operation(
 def map_array_operation(
     array: ArrayVar[ARRAY_VAR_TYPE],
     array: ArrayVar[ARRAY_VAR_TYPE],

+ 66 - 6
tests/units/vars/test_object.py

@@ -1,10 +1,14 @@
+import dataclasses
+
 import pytest
 import pytest
+from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
 from typing_extensions import assert_type
 from typing_extensions import assert_type
 
 
 import reflex as rx
 import reflex as rx
 from reflex.utils.types import GenericType
 from reflex.utils.types import GenericType
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 from reflex.vars.object import LiteralObjectVar, ObjectVar
 from reflex.vars.object import LiteralObjectVar, ObjectVar
+from reflex.vars.sequence import ArrayVar
 
 
 
 
 class Bare:
 class Bare:
@@ -32,14 +36,44 @@ class Base(rx.Base):
     quantity: int = 0
     quantity: int = 0
 
 
 
 
+class SqlaBase(DeclarativeBase, MappedAsDataclass):
+    """Sqlalchemy declarative mapping base class."""
+
+    pass
+
+
+class SqlaModel(SqlaBase):
+    """A sqlalchemy model with a single attribute."""
+
+    __tablename__: str = "sqla_model"
+
+    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
+    quantity: Mapped[int] = mapped_column(default=0)
+
+
+@dataclasses.dataclass
+class Dataclass:
+    """A dataclass with a single attribute."""
+
+    quantity: int = 0
+
+
 class ObjectState(rx.State):
 class ObjectState(rx.State):
-    """A reflex state with bare and base objects."""
+    """A reflex state with bare, base and sqlalchemy base vars."""
 
 
     bare: rx.Field[Bare] = rx.field(Bare())
     bare: rx.Field[Bare] = rx.field(Bare())
+    bare_optional: rx.Field[Bare | None] = rx.field(None)
     base: rx.Field[Base] = rx.field(Base())
     base: rx.Field[Base] = rx.field(Base())
+    base_optional: rx.Field[Base | None] = rx.field(None)
+    sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
+    sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
+    dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
+    dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
+
+    base_list: rx.Field[list[Base]] = rx.field([Base()])
 
 
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_var_create(type_: GenericType) -> None:
 def test_var_create(type_: GenericType) -> None:
     my_object = type_()
     my_object = type_()
     var = Var.create(my_object)
     var = Var.create(my_object)
@@ -49,7 +83,7 @@ def test_var_create(type_: GenericType) -> None:
     assert quantity._var_type is int
     assert quantity._var_type is int
 
 
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_literal_create(type_: GenericType) -> None:
 def test_literal_create(type_: GenericType) -> None:
     my_object = type_()
     my_object = type_()
     var = LiteralObjectVar.create(my_object)
     var = LiteralObjectVar.create(my_object)
@@ -59,7 +93,7 @@ def test_literal_create(type_: GenericType) -> None:
     assert quantity._var_type is int
     assert quantity._var_type is int
 
 
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_guess(type_: GenericType) -> None:
 def test_guess(type_: GenericType) -> None:
     my_object = type_()
     my_object = type_()
     var = Var.create(my_object)
     var = Var.create(my_object)
@@ -70,7 +104,7 @@ def test_guess(type_: GenericType) -> None:
     assert quantity._var_type is int
     assert quantity._var_type is int
 
 
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_state(type_: GenericType) -> None:
 def test_state(type_: GenericType) -> None:
     attr_name = type_.__name__.lower()
     attr_name = type_.__name__.lower()
     var = getattr(ObjectState, attr_name)
     var = getattr(ObjectState, attr_name)
@@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None:
     assert quantity._var_type is int
     assert quantity._var_type is int
 
 
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_state_to_operation(type_: GenericType) -> None:
 def test_state_to_operation(type_: GenericType) -> None:
     attr_name = type_.__name__.lower()
     attr_name = type_.__name__.lower()
     original_var = getattr(ObjectState, attr_name)
     original_var = getattr(ObjectState, attr_name)
@@ -100,3 +134,29 @@ def test_typing() -> None:
     # Base
     # Base
     var = ObjectState.base
     var = ObjectState.base
     _ = assert_type(var, ObjectVar[Base])
     _ = assert_type(var, ObjectVar[Base])
+    optional_var = ObjectState.base_optional
+    _ = assert_type(optional_var, ObjectVar[Base | None])
+    list_var = ObjectState.base_list
+    _ = assert_type(list_var, ArrayVar[list[Base]])
+    list_var_0 = list_var[0]
+    _ = assert_type(list_var_0, ObjectVar[Base])
+
+    # Sqla
+    var = ObjectState.sqlamodel
+    _ = assert_type(var, ObjectVar[SqlaModel])
+    optional_var = ObjectState.sqlamodel_optional
+    _ = assert_type(optional_var, ObjectVar[SqlaModel | None])
+    list_var = ObjectState.base_list
+    _ = assert_type(list_var, ArrayVar[list[Base]])
+    list_var_0 = list_var[0]
+    _ = assert_type(list_var_0, ObjectVar[Base])
+
+    # Dataclass
+    var = ObjectState.dataclass
+    _ = assert_type(var, ObjectVar[Dataclass])
+    optional_var = ObjectState.dataclass_optional
+    _ = assert_type(optional_var, ObjectVar[Dataclass | None])
+    list_var = ObjectState.base_list
+    _ = assert_type(list_var, ArrayVar[list[Base]])
+    list_var_0 = list_var[0]
+    _ = assert_type(list_var_0, ObjectVar[Base])