|
@@ -1,10 +1,14 @@
|
|
|
|
+import dataclasses
|
|
|
|
+
|
|
import pytest
|
|
import pytest
|
|
|
|
+from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
|
|
from typing_extensions import assert_type
|
|
from typing_extensions import assert_type
|
|
|
|
|
|
import reflex as rx
|
|
import reflex as rx
|
|
from reflex.utils.types import GenericType
|
|
from reflex.utils.types import GenericType
|
|
from reflex.vars.base import Var
|
|
from reflex.vars.base import Var
|
|
from reflex.vars.object import LiteralObjectVar, ObjectVar
|
|
from reflex.vars.object import LiteralObjectVar, ObjectVar
|
|
|
|
+from reflex.vars.sequence import ArrayVar
|
|
|
|
|
|
|
|
|
|
class Bare:
|
|
class Bare:
|
|
@@ -32,14 +36,44 @@ class Base(rx.Base):
|
|
quantity: int = 0
|
|
quantity: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
+class SqlaBase(DeclarativeBase, MappedAsDataclass):
|
|
|
|
+ """Sqlalchemy declarative mapping base class."""
|
|
|
|
+
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class SqlaModel(SqlaBase):
|
|
|
|
+ """A sqlalchemy model with a single attribute."""
|
|
|
|
+
|
|
|
|
+ __tablename__: str = "sqla_model"
|
|
|
|
+
|
|
|
|
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
|
|
|
|
+ quantity: Mapped[int] = mapped_column(default=0)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@dataclasses.dataclass
|
|
|
|
+class Dataclass:
|
|
|
|
+ """A dataclass with a single attribute."""
|
|
|
|
+
|
|
|
|
+ quantity: int = 0
|
|
|
|
+
|
|
|
|
+
|
|
class ObjectState(rx.State):
|
|
class ObjectState(rx.State):
|
|
- """A reflex state with bare and base objects."""
|
|
|
|
|
|
+ """A reflex state with bare, base and sqlalchemy base vars."""
|
|
|
|
|
|
bare: rx.Field[Bare] = rx.field(Bare())
|
|
bare: rx.Field[Bare] = rx.field(Bare())
|
|
|
|
+ bare_optional: rx.Field[Bare | None] = rx.field(None)
|
|
base: rx.Field[Base] = rx.field(Base())
|
|
base: rx.Field[Base] = rx.field(Base())
|
|
|
|
+ base_optional: rx.Field[Base | None] = rx.field(None)
|
|
|
|
+ sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
|
|
|
|
+ sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
|
|
|
|
+ dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
|
|
|
|
+ dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
|
|
|
|
+
|
|
|
|
+ base_list: rx.Field[list[Base]] = rx.field([Base()])
|
|
|
|
|
|
|
|
|
|
-@pytest.mark.parametrize("type_", [Base, Bare])
|
|
|
|
|
|
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
|
def test_var_create(type_: GenericType) -> None:
|
|
def test_var_create(type_: GenericType) -> None:
|
|
my_object = type_()
|
|
my_object = type_()
|
|
var = Var.create(my_object)
|
|
var = Var.create(my_object)
|
|
@@ -49,7 +83,7 @@ def test_var_create(type_: GenericType) -> None:
|
|
assert quantity._var_type is int
|
|
assert quantity._var_type is int
|
|
|
|
|
|
|
|
|
|
-@pytest.mark.parametrize("type_", [Base, Bare])
|
|
|
|
|
|
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
|
def test_literal_create(type_: GenericType) -> None:
|
|
def test_literal_create(type_: GenericType) -> None:
|
|
my_object = type_()
|
|
my_object = type_()
|
|
var = LiteralObjectVar.create(my_object)
|
|
var = LiteralObjectVar.create(my_object)
|
|
@@ -59,7 +93,7 @@ def test_literal_create(type_: GenericType) -> None:
|
|
assert quantity._var_type is int
|
|
assert quantity._var_type is int
|
|
|
|
|
|
|
|
|
|
-@pytest.mark.parametrize("type_", [Base, Bare])
|
|
|
|
|
|
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
|
def test_guess(type_: GenericType) -> None:
|
|
def test_guess(type_: GenericType) -> None:
|
|
my_object = type_()
|
|
my_object = type_()
|
|
var = Var.create(my_object)
|
|
var = Var.create(my_object)
|
|
@@ -70,7 +104,7 @@ def test_guess(type_: GenericType) -> None:
|
|
assert quantity._var_type is int
|
|
assert quantity._var_type is int
|
|
|
|
|
|
|
|
|
|
-@pytest.mark.parametrize("type_", [Base, Bare])
|
|
|
|
|
|
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
|
def test_state(type_: GenericType) -> None:
|
|
def test_state(type_: GenericType) -> None:
|
|
attr_name = type_.__name__.lower()
|
|
attr_name = type_.__name__.lower()
|
|
var = getattr(ObjectState, attr_name)
|
|
var = getattr(ObjectState, attr_name)
|
|
@@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None:
|
|
assert quantity._var_type is int
|
|
assert quantity._var_type is int
|
|
|
|
|
|
|
|
|
|
-@pytest.mark.parametrize("type_", [Base, Bare])
|
|
|
|
|
|
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
|
def test_state_to_operation(type_: GenericType) -> None:
|
|
def test_state_to_operation(type_: GenericType) -> None:
|
|
attr_name = type_.__name__.lower()
|
|
attr_name = type_.__name__.lower()
|
|
original_var = getattr(ObjectState, attr_name)
|
|
original_var = getattr(ObjectState, attr_name)
|
|
@@ -100,3 +134,29 @@ def test_typing() -> None:
|
|
# Base
|
|
# Base
|
|
var = ObjectState.base
|
|
var = ObjectState.base
|
|
_ = assert_type(var, ObjectVar[Base])
|
|
_ = assert_type(var, ObjectVar[Base])
|
|
|
|
+ optional_var = ObjectState.base_optional
|
|
|
|
+ _ = assert_type(optional_var, ObjectVar[Base | None])
|
|
|
|
+ list_var = ObjectState.base_list
|
|
|
|
+ _ = assert_type(list_var, ArrayVar[list[Base]])
|
|
|
|
+ list_var_0 = list_var[0]
|
|
|
|
+ _ = assert_type(list_var_0, ObjectVar[Base])
|
|
|
|
+
|
|
|
|
+ # Sqla
|
|
|
|
+ var = ObjectState.sqlamodel
|
|
|
|
+ _ = assert_type(var, ObjectVar[SqlaModel])
|
|
|
|
+ optional_var = ObjectState.sqlamodel_optional
|
|
|
|
+ _ = assert_type(optional_var, ObjectVar[SqlaModel | None])
|
|
|
|
+ list_var = ObjectState.base_list
|
|
|
|
+ _ = assert_type(list_var, ArrayVar[list[Base]])
|
|
|
|
+ list_var_0 = list_var[0]
|
|
|
|
+ _ = assert_type(list_var_0, ObjectVar[Base])
|
|
|
|
+
|
|
|
|
+ # Dataclass
|
|
|
|
+ var = ObjectState.dataclass
|
|
|
|
+ _ = assert_type(var, ObjectVar[Dataclass])
|
|
|
|
+ optional_var = ObjectState.dataclass_optional
|
|
|
|
+ _ = assert_type(optional_var, ObjectVar[Dataclass | None])
|
|
|
|
+ list_var = ObjectState.base_list
|
|
|
|
+ _ = assert_type(list_var, ArrayVar[list[Base]])
|
|
|
|
+ list_var_0 = list_var[0]
|
|
|
|
+ _ = assert_type(list_var_0, ObjectVar[Base])
|