Bladeren bron

improve rx.Field ObjectVar typing for sqlalchemy and dataclasses (#4728)

* improve rx.Field ObjectVar typing for sqlalchemy and dataclasses

* enable parametrized objectvar tests for sqlamodel and dataclass

* improve typing for ObjectVars in ArrayVars

* ruffing

* drop duplicate objectvar import

* remove redundant overload

* allow optional hints in rx.Field annotations to resolve to the correct var type
benedikt-bartscher 3 maanden geleden
bovenliggende
commit
2b7e4d6b4e
3 gewijzigde bestanden met toevoegingen van 111 en 15 verwijderingen
  1. 24 5
      reflex/vars/base.py
  2. 21 4
      reflex/vars/sequence.py
  3. 66 6
      tests/units/vars/test_object.py

+ 24 - 5
reflex/vars/base.py

@@ -40,6 +40,7 @@ from typing import (
     overload,
 )
 
+from sqlalchemy.orm import DeclarativeBase
 from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
 
 from reflex import constants
@@ -573,7 +574,7 @@ class Var(Generic[VAR_TYPE]):
 
     @overload
     @classmethod
-    def create(  # type: ignore[override]
+    def create(  # pyright: ignore[reportOverlappingOverload]
         cls,
         value: bool,
         _var_data: VarData | None = None,
@@ -581,7 +582,7 @@ class Var(Generic[VAR_TYPE]):
 
     @overload
     @classmethod
-    def create(  # type: ignore[override]
+    def create(
         cls,
         value: int,
         _var_data: VarData | None = None,
@@ -605,7 +606,7 @@ class Var(Generic[VAR_TYPE]):
 
     @overload
     @classmethod
-    def create(
+    def create(  # pyright: ignore[reportOverlappingOverload]
         cls,
         value: None,
         _var_data: VarData | None = None,
@@ -3182,10 +3183,16 @@ def dispatch(
 
 V = TypeVar("V")
 
-BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
+BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None)
+SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None)
+
+if TYPE_CHECKING:
+    from _typeshed import DataclassInstance
+
+    DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None)
 
 FIELD_TYPE = TypeVar("FIELD_TYPE")
-MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
+MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None)
 
 
 class Field(Generic[FIELD_TYPE]):
@@ -3230,6 +3237,18 @@ class Field(Generic[FIELD_TYPE]):
         self: Field[BASE_TYPE], instance: None, owner: Any
     ) -> ObjectVar[BASE_TYPE]: ...
 
+    @overload
+    def __get__(
+        self: Field[SQLA_TYPE], instance: None, owner: Any
+    ) -> ObjectVar[SQLA_TYPE]: ...
+
+    if TYPE_CHECKING:
+
+        @overload
+        def __get__(
+            self: Field[DATACLASS_TYPE], instance: None, owner: Any
+        ) -> ObjectVar[DATACLASS_TYPE]: ...
+
     @overload
     def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
 

+ 21 - 4
reflex/vars/sequence.py

@@ -53,8 +53,11 @@ from .number import (
 )
 
 if TYPE_CHECKING:
+    from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE
+    from .function import FunctionVar
     from .object import ObjectVar
 
+
 STRING_TYPE = TypeVar("STRING_TYPE", default=str)
 
 
@@ -961,6 +964,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
         i: int | NumberVar,
     ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
 
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE],
+        i: int | NumberVar,
+    ) -> ObjectVar[BASE_TYPE]: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE],
+        i: int | NumberVar,
+    ) -> ObjectVar[SQLA_TYPE]: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE],
+        i: int | NumberVar,
+    ) -> ObjectVar[DATACLASS_TYPE]: ...
+
     @overload
     def __getitem__(self, i: int | NumberVar) -> Var: ...
 
@@ -1648,10 +1669,6 @@ def repeat_array_operation(
     )
 
 
-if TYPE_CHECKING:
-    from .function import FunctionVar
-
-
 @var_operation
 def map_array_operation(
     array: ArrayVar[ARRAY_VAR_TYPE],

+ 66 - 6
tests/units/vars/test_object.py

@@ -1,10 +1,14 @@
+import dataclasses
+
 import pytest
+from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
 from typing_extensions import assert_type
 
 import reflex as rx
 from reflex.utils.types import GenericType
 from reflex.vars.base import Var
 from reflex.vars.object import LiteralObjectVar, ObjectVar
+from reflex.vars.sequence import ArrayVar
 
 
 class Bare:
@@ -32,14 +36,44 @@ class Base(rx.Base):
     quantity: int = 0
 
 
+class SqlaBase(DeclarativeBase, MappedAsDataclass):
+    """Sqlalchemy declarative mapping base class."""
+
+    pass
+
+
+class SqlaModel(SqlaBase):
+    """A sqlalchemy model with a single attribute."""
+
+    __tablename__: str = "sqla_model"
+
+    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
+    quantity: Mapped[int] = mapped_column(default=0)
+
+
+@dataclasses.dataclass
+class Dataclass:
+    """A dataclass with a single attribute."""
+
+    quantity: int = 0
+
+
 class ObjectState(rx.State):
-    """A reflex state with bare and base objects."""
+    """A reflex state with bare, base and sqlalchemy base vars."""
 
     bare: rx.Field[Bare] = rx.field(Bare())
+    bare_optional: rx.Field[Bare | None] = rx.field(None)
     base: rx.Field[Base] = rx.field(Base())
+    base_optional: rx.Field[Base | None] = rx.field(None)
+    sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
+    sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
+    dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
+    dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
+
+    base_list: rx.Field[list[Base]] = rx.field([Base()])
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_var_create(type_: GenericType) -> None:
     my_object = type_()
     var = Var.create(my_object)
@@ -49,7 +83,7 @@ def test_var_create(type_: GenericType) -> None:
     assert quantity._var_type is int
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_literal_create(type_: GenericType) -> None:
     my_object = type_()
     var = LiteralObjectVar.create(my_object)
@@ -59,7 +93,7 @@ def test_literal_create(type_: GenericType) -> None:
     assert quantity._var_type is int
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_guess(type_: GenericType) -> None:
     my_object = type_()
     var = Var.create(my_object)
@@ -70,7 +104,7 @@ def test_guess(type_: GenericType) -> None:
     assert quantity._var_type is int
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_state(type_: GenericType) -> None:
     attr_name = type_.__name__.lower()
     var = getattr(ObjectState, attr_name)
@@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None:
     assert quantity._var_type is int
 
 
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
 def test_state_to_operation(type_: GenericType) -> None:
     attr_name = type_.__name__.lower()
     original_var = getattr(ObjectState, attr_name)
@@ -100,3 +134,29 @@ def test_typing() -> None:
     # Base
     var = ObjectState.base
     _ = assert_type(var, ObjectVar[Base])
+    optional_var = ObjectState.base_optional
+    _ = assert_type(optional_var, ObjectVar[Base | None])
+    list_var = ObjectState.base_list
+    _ = assert_type(list_var, ArrayVar[list[Base]])
+    list_var_0 = list_var[0]
+    _ = assert_type(list_var_0, ObjectVar[Base])
+
+    # Sqla
+    var = ObjectState.sqlamodel
+    _ = assert_type(var, ObjectVar[SqlaModel])
+    optional_var = ObjectState.sqlamodel_optional
+    _ = assert_type(optional_var, ObjectVar[SqlaModel | None])
+    list_var = ObjectState.base_list
+    _ = assert_type(list_var, ArrayVar[list[Base]])
+    list_var_0 = list_var[0]
+    _ = assert_type(list_var_0, ObjectVar[Base])
+
+    # Dataclass
+    var = ObjectState.dataclass
+    _ = assert_type(var, ObjectVar[Dataclass])
+    optional_var = ObjectState.dataclass_optional
+    _ = assert_type(optional_var, ObjectVar[Dataclass | None])
+    list_var = ObjectState.base_list
+    _ = assert_type(list_var, ArrayVar[list[Base]])
+    list_var_0 = list_var[0]
+    _ = assert_type(list_var_0, ObjectVar[Base])