test_object.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import dataclasses
  2. from typing import Sequence
  3. import pytest
  4. from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
  5. from typing_extensions import assert_type
  6. import reflex as rx
  7. from reflex.utils.types import GenericType
  8. from reflex.vars.base import Var
  9. from reflex.vars.object import LiteralObjectVar, ObjectVar
  10. from reflex.vars.sequence import ArrayVar
  11. class Bare:
  12. """A bare class with a single attribute."""
  13. quantity: int = 0
  14. @rx.serializer
  15. def serialize_bare(obj: Bare) -> dict:
  16. """A serializer for the bare class.
  17. Args:
  18. obj: The object to serialize.
  19. Returns:
  20. A dictionary with the quantity attribute.
  21. """
  22. return {"quantity": obj.quantity}
  23. class Base(rx.Base):
  24. """A reflex base class with a single attribute."""
  25. quantity: int = 0
  26. class SqlaBase(DeclarativeBase, MappedAsDataclass):
  27. """Sqlalchemy declarative mapping base class."""
  28. pass
  29. class SqlaModel(SqlaBase):
  30. """A sqlalchemy model with a single attribute."""
  31. __tablename__: str = "sqla_model"
  32. id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
  33. quantity: Mapped[int] = mapped_column(default=0)
  34. @dataclasses.dataclass
  35. class Dataclass:
  36. """A dataclass with a single attribute."""
  37. quantity: int = 0
class ObjectState(rx.State):
    """A reflex state with bare, base, sqlalchemy and dataclass vars.

    Declares one required and one optional field per object flavour (plain
    class, ``rx.Base`` subclass, SQLAlchemy model, dataclass), plus a list of
    ``Base``. The parametrized tests look fields up via
    ``type_.__name__.lower()``, so each required field's name must equal its
    class name lowercased.
    """

    # Plain class with no base class.
    bare: rx.Field[Bare] = rx.field(Bare())
    bare_optional: rx.Field[Bare | None] = rx.field(None)
    # rx.Base subclass.
    base: rx.Field[Base] = rx.field(Base())
    base_optional: rx.Field[Base | None] = rx.field(None)
    # SQLAlchemy mapped dataclass model.
    sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
    sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
    # Standard-library dataclass.
    dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
    dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
    # The only list-typed field; used for ArrayVar indexing type checks.
    base_list: rx.Field[list[Base]] = rx.field([Base()])
  49. @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
  50. def test_var_create(type_: type[Base | Bare | SqlaModel | Dataclass]) -> None:
  51. my_object = type_()
  52. var = Var.create(my_object)
  53. assert var._var_type is type_
  54. assert isinstance(var, ObjectVar)
  55. quantity = var.quantity
  56. assert quantity._var_type is int
  57. @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
  58. def test_literal_create(type_: GenericType) -> None:
  59. my_object = type_()
  60. var = LiteralObjectVar.create(my_object)
  61. assert var._var_type is type_
  62. quantity = var.quantity
  63. assert quantity._var_type is int
  64. @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
  65. def test_guess(type_: type[Base | Bare | SqlaModel | Dataclass]) -> None:
  66. my_object = type_()
  67. var = Var.create(my_object)
  68. var = var.guess_type()
  69. assert var._var_type is type_
  70. assert isinstance(var, ObjectVar)
  71. quantity = var.quantity
  72. assert quantity._var_type is int
  73. @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
  74. def test_state(type_: GenericType) -> None:
  75. attr_name = type_.__name__.lower()
  76. var = getattr(ObjectState, attr_name)
  77. assert var._var_type is type_
  78. quantity = var.quantity
  79. assert quantity._var_type is int
  80. @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
  81. def test_state_to_operation(type_: GenericType) -> None:
  82. attr_name = type_.__name__.lower()
  83. original_var = getattr(ObjectState, attr_name)
  84. var = original_var.to(ObjectVar, type_)
  85. assert var._var_type is type_
  86. var = original_var.to(ObjectVar)
  87. assert var._var_type is type_
  88. def test_typing() -> None:
  89. # Bare
  90. var = ObjectState.bare.to(ObjectVar)
  91. _ = assert_type(var, ObjectVar[Bare])
  92. # Base
  93. var = ObjectState.base
  94. _ = assert_type(var, ObjectVar[Base])
  95. optional_var = ObjectState.base_optional
  96. _ = assert_type(optional_var, ObjectVar[Base])
  97. list_var = ObjectState.base_list
  98. _ = assert_type(list_var, ArrayVar[Sequence[Base]])
  99. list_var_0 = list_var[0]
  100. _ = assert_type(list_var_0, ObjectVar[Base])
  101. # Sqla
  102. var = ObjectState.sqlamodel
  103. _ = assert_type(var, ObjectVar[SqlaModel])
  104. optional_var = ObjectState.sqlamodel_optional
  105. _ = assert_type(optional_var, ObjectVar[SqlaModel])
  106. list_var = ObjectState.base_list
  107. _ = assert_type(list_var, ArrayVar[Sequence[Base]])
  108. list_var_0 = list_var[0]
  109. _ = assert_type(list_var_0, ObjectVar[Base])
  110. # Dataclass
  111. var = ObjectState.dataclass
  112. _ = assert_type(var, ObjectVar[Dataclass])
  113. optional_var = ObjectState.dataclass_optional
  114. _ = assert_type(optional_var, ObjectVar[Dataclass])
  115. list_var = ObjectState.base_list
  116. _ = assert_type(list_var, ArrayVar[Sequence[Base]])
  117. list_var_0 = list_var[0]
  118. _ = assert_type(list_var_0, ObjectVar[Base])