1
0

test_base.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from typing import List, Mapping, Union
  2. import pytest
  3. from reflex.state import State
  4. from reflex.vars.base import computed_var, figure_out_type
  5. class CustomDict(dict[str, str]):
  6. """A custom dict with generic arguments."""
  7. pass
  8. class ChildCustomDict(CustomDict):
  9. """A child of CustomDict."""
  10. pass
  11. class GenericDict(dict):
  12. """A generic dict with no generic arguments."""
  13. pass
  14. class ChildGenericDict(GenericDict):
  15. """A child of GenericDict."""
  16. pass
  17. @pytest.mark.parametrize(
  18. ("value", "expected"),
  19. [
  20. (1, int),
  21. (1.0, float),
  22. ("a", str),
  23. ([1, 2, 3], List[int]),
  24. ([1, 2.0, "a"], List[Union[int, float, str]]),
  25. ({"a": 1, "b": 2}, Mapping[str, int]),
  26. ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]),
  27. (CustomDict(), CustomDict),
  28. (ChildCustomDict(), ChildCustomDict),
  29. (GenericDict({1: 1}), Mapping[int, int]),
  30. (ChildGenericDict({1: 1}), Mapping[int, int]),
  31. ],
  32. )
  33. def test_figure_out_type(value, expected):
  34. assert figure_out_type(value) == expected
  35. def test_computed_var_replace() -> None:
  36. class StateTest(State):
  37. @computed_var(cache=True)
  38. def cv(self) -> int:
  39. return 1
  40. cv = StateTest.cv
  41. assert cv._var_type is int
  42. replaced = cv._replace(_var_type=float)
  43. assert replaced._var_type is float