test_var.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. import typing
  2. from typing import Dict, List, Set, Tuple
  3. import cloudpickle
  4. import pytest
  5. from pandas import DataFrame
  6. from reflex.base import Base
  7. from reflex.state import State
  8. from reflex.vars import (
  9. BaseVar,
  10. ComputedVar,
  11. ImportVar,
  12. ReflexDict,
  13. ReflexList,
  14. ReflexSet,
  15. Var,
  16. get_local_storage,
  17. )
  18. test_vars = [
  19. BaseVar(name="prop1", type_=int),
  20. BaseVar(name="key", type_=str),
  21. BaseVar(name="value", type_=str, state="state"),
  22. BaseVar(name="local", type_=str, state="state", is_local=True),
  23. BaseVar(name="local2", type_=str, is_local=True),
  24. ]
  25. test_import_vars = [ImportVar(tag="DataGrid"), ImportVar(tag="DataGrid", alias="Grid")]
  26. class BaseState(State):
  27. """A Test State."""
  28. val: str = "key"
  29. @pytest.fixture
  30. def TestObj():
  31. class TestObj(Base):
  32. foo: int
  33. bar: str
  34. return TestObj
  35. @pytest.fixture
  36. def ParentState(TestObj):
  37. class ParentState(State):
  38. foo: int
  39. bar: int
  40. @ComputedVar
  41. def var_without_annotation(self):
  42. return TestObj
  43. return ParentState
  44. @pytest.fixture
  45. def ChildState(ParentState, TestObj):
  46. class ChildState(ParentState):
  47. @ComputedVar
  48. def var_without_annotation(self):
  49. return TestObj
  50. return ChildState
  51. @pytest.fixture
  52. def GrandChildState(ChildState, TestObj):
  53. class GrandChildState(ChildState):
  54. @ComputedVar
  55. def var_without_annotation(self):
  56. return TestObj
  57. return GrandChildState
  58. @pytest.fixture
  59. def StateWithAnyVar(TestObj):
  60. class StateWithAnyVar(State):
  61. @ComputedVar
  62. def var_without_annotation(self) -> typing.Any:
  63. return TestObj
  64. return StateWithAnyVar
  65. @pytest.fixture
  66. def StateWithCorrectVarAnnotation():
  67. class StateWithCorrectVarAnnotation(State):
  68. @ComputedVar
  69. def var_with_annotation(self) -> str:
  70. return "Correct annotation"
  71. return StateWithCorrectVarAnnotation
  72. @pytest.fixture
  73. def StateWithWrongVarAnnotation(TestObj):
  74. class StateWithWrongVarAnnotation(State):
  75. @ComputedVar
  76. def var_with_annotation(self) -> str:
  77. return TestObj
  78. return StateWithWrongVarAnnotation
  79. @pytest.mark.parametrize(
  80. "prop,expected",
  81. zip(
  82. test_vars,
  83. [
  84. "prop1",
  85. "key",
  86. "state.value",
  87. "state.local",
  88. "local2",
  89. ],
  90. ),
  91. )
  92. def test_full_name(prop, expected):
  93. """Test that the full name of a var is correct.
  94. Args:
  95. prop: The var to test.
  96. expected: The expected full name.
  97. """
  98. assert prop.full_name == expected
  99. @pytest.mark.parametrize(
  100. "prop,expected",
  101. zip(
  102. test_vars,
  103. ["{prop1}", "{key}", "{state.value}", "state.local", "local2"],
  104. ),
  105. )
  106. def test_str(prop, expected):
  107. """Test that the string representation of a var is correct.
  108. Args:
  109. prop: The var to test.
  110. expected: The expected string representation.
  111. """
  112. assert str(prop) == expected
  113. @pytest.mark.parametrize(
  114. "prop,expected",
  115. [
  116. (BaseVar(name="p", type_=int), 0),
  117. (BaseVar(name="p", type_=float), 0.0),
  118. (BaseVar(name="p", type_=str), ""),
  119. (BaseVar(name="p", type_=bool), False),
  120. (BaseVar(name="p", type_=list), []),
  121. (BaseVar(name="p", type_=dict), {}),
  122. (BaseVar(name="p", type_=tuple), ()),
  123. (BaseVar(name="p", type_=set), set()),
  124. ],
  125. )
  126. def test_default_value(prop, expected):
  127. """Test that the default value of a var is correct.
  128. Args:
  129. prop: The var to test.
  130. expected: The expected default value.
  131. """
  132. assert prop.get_default_value() == expected
  133. @pytest.mark.parametrize(
  134. "prop,expected",
  135. zip(
  136. test_vars,
  137. [
  138. "set_prop1",
  139. "set_key",
  140. "state.set_value",
  141. "state.set_local",
  142. "set_local2",
  143. ],
  144. ),
  145. )
  146. def test_get_setter(prop, expected):
  147. """Test that the name of the setter function of a var is correct.
  148. Args:
  149. prop: The var to test.
  150. expected: The expected name of the setter function.
  151. """
  152. assert prop.get_setter_name() == expected
  153. @pytest.mark.parametrize(
  154. "value,expected",
  155. [
  156. (None, None),
  157. (1, BaseVar(name="1", type_=int, is_local=True)),
  158. ("key", BaseVar(name="key", type_=str, is_local=True)),
  159. (3.14, BaseVar(name="3.14", type_=float, is_local=True)),
  160. ([1, 2, 3], BaseVar(name="[1, 2, 3]", type_=list, is_local=True)),
  161. (
  162. {"a": 1, "b": 2},
  163. BaseVar(name='{"a": 1, "b": 2}', type_=dict, is_local=True),
  164. ),
  165. ],
  166. )
  167. def test_create(value, expected):
  168. """Test the var create function.
  169. Args:
  170. value: The value to create a var from.
  171. expected: The expected name of the setter function.
  172. """
  173. prop = Var.create(value)
  174. if value is None:
  175. assert prop == expected
  176. else:
  177. assert prop.equals(expected) # type: ignore
  178. def test_create_type_error():
  179. """Test the var create function when inputs type error."""
  180. class ErrorType:
  181. pass
  182. value = ErrorType()
  183. with pytest.raises(TypeError) as exception:
  184. Var.create(value)
  185. assert (
  186. exception.value.args[0]
  187. == f"To create a Var must be Var or JSON-serializable. Got {value} of type {type(value)}."
  188. )
  189. def v(value) -> Var:
  190. val = Var.create(value)
  191. assert val is not None
  192. return val
  193. def test_basic_operations(TestObj):
  194. """Test the var operations.
  195. Args:
  196. TestObj: The test object.
  197. """
  198. assert str(v(1) == v(2)) == "{(1 === 2)}"
  199. assert str(v(1) != v(2)) == "{(1 !== 2)}"
  200. assert str(v(1) < v(2)) == "{(1 < 2)}"
  201. assert str(v(1) <= v(2)) == "{(1 <= 2)}"
  202. assert str(v(1) > v(2)) == "{(1 > 2)}"
  203. assert str(v(1) >= v(2)) == "{(1 >= 2)}"
  204. assert str(v(1) + v(2)) == "{(1 + 2)}"
  205. assert str(v(1) - v(2)) == "{(1 - 2)}"
  206. assert str(v(1) * v(2)) == "{(1 * 2)}"
  207. assert str(v(1) / v(2)) == "{(1 / 2)}"
  208. assert str(v(1) // v(2)) == "{Math.floor(1 / 2)}"
  209. assert str(v(1) % v(2)) == "{(1 % 2)}"
  210. assert str(v(1) ** v(2)) == "{Math.pow(1 , 2)}"
  211. assert str(v(1) & v(2)) == "{(1 && 2)}"
  212. assert str(v(1) | v(2)) == "{(1 || 2)}"
  213. assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}"
  214. assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}'
  215. assert (
  216. str(BaseVar(name="foo", state="state", type_=TestObj).bar) == "{state.foo.bar}"
  217. )
  218. assert str(abs(v(1))) == "{Math.abs(1)}"
  219. assert str(v([1, 2, 3]).length()) == "{[1, 2, 3].length}"
  220. @pytest.mark.parametrize(
  221. "var",
  222. [
  223. BaseVar(name="list", type_=List[int]),
  224. BaseVar(name="tuple", type_=Tuple[int, int]),
  225. BaseVar(name="str", type_=str),
  226. ],
  227. )
  228. def test_var_indexing_lists(var):
  229. """Test that we can index into str, list or tuple vars.
  230. Args:
  231. var : The str, list or tuple base var.
  232. """
  233. # Test basic indexing.
  234. assert str(var[0]) == f"{{{var.name}.at(0)}}"
  235. assert str(var[1]) == f"{{{var.name}.at(1)}}"
  236. # Test negative indexing.
  237. assert str(var[-1]) == f"{{{var.name}.at(-1)}}"
  238. @pytest.mark.parametrize(
  239. "var, index",
  240. [
  241. (BaseVar(name="lst", type_=List[int]), [1, 2]),
  242. (BaseVar(name="lst", type_=List[int]), {"name": "dict"}),
  243. (BaseVar(name="lst", type_=List[int]), {"set"}),
  244. (
  245. BaseVar(name="lst", type_=List[int]),
  246. (
  247. 1,
  248. 2,
  249. ),
  250. ),
  251. (BaseVar(name="lst", type_=List[int]), 1.5),
  252. (BaseVar(name="lst", type_=List[int]), "str"),
  253. (BaseVar(name="lst", type_=List[int]), BaseVar(name="string_var", type_=str)),
  254. (BaseVar(name="lst", type_=List[int]), BaseVar(name="float_var", type_=float)),
  255. (
  256. BaseVar(name="lst", type_=List[int]),
  257. BaseVar(name="list_var", type_=List[int]),
  258. ),
  259. (BaseVar(name="lst", type_=List[int]), BaseVar(name="set_var", type_=Set[str])),
  260. (
  261. BaseVar(name="lst", type_=List[int]),
  262. BaseVar(name="dict_var", type_=Dict[str, str]),
  263. ),
  264. (BaseVar(name="str", type_=str), [1, 2]),
  265. (BaseVar(name="lst", type_=str), {"name": "dict"}),
  266. (BaseVar(name="lst", type_=str), {"set"}),
  267. (BaseVar(name="lst", type_=str), BaseVar(name="string_var", type_=str)),
  268. (BaseVar(name="lst", type_=str), BaseVar(name="float_var", type_=float)),
  269. (BaseVar(name="str", type_=Tuple[str]), [1, 2]),
  270. (BaseVar(name="lst", type_=Tuple[str]), {"name": "dict"}),
  271. (BaseVar(name="lst", type_=Tuple[str]), {"set"}),
  272. (BaseVar(name="lst", type_=Tuple[str]), BaseVar(name="string_var", type_=str)),
  273. (BaseVar(name="lst", type_=Tuple[str]), BaseVar(name="float_var", type_=float)),
  274. ],
  275. )
  276. def test_var_unsupported_indexing_lists(var, index):
  277. """Test unsupported indexing throws a type error.
  278. Args:
  279. var: The base var.
  280. index: The base var index.
  281. """
  282. with pytest.raises(TypeError):
  283. var[index]
  284. @pytest.mark.parametrize(
  285. "var",
  286. [
  287. BaseVar(name="lst", type_=List[int]),
  288. BaseVar(name="tuple", type_=Tuple[int, int]),
  289. BaseVar(name="str", type_=str),
  290. ],
  291. )
  292. def test_var_list_slicing(var):
  293. """Test that we can slice into str, list or tuple vars.
  294. Args:
  295. var : The str, list or tuple base var.
  296. """
  297. assert str(var[:1]) == f"{{{var.name}.slice(0, 1)}}"
  298. assert str(var[:1]) == f"{{{var.name}.slice(0, 1)}}"
  299. assert str(var[:]) == f"{{{var.name}.slice(0, undefined)}}"
  300. def test_dict_indexing():
  301. """Test that we can index into dict vars."""
  302. dct = BaseVar(name="dct", type_=Dict[str, int])
  303. # Check correct indexing.
  304. assert str(dct["a"]) == '{dct["a"]}'
  305. assert str(dct["asdf"]) == '{dct["asdf"]}'
  306. @pytest.mark.parametrize(
  307. "var, index",
  308. [
  309. (
  310. BaseVar(name="dict", type_=Dict[str, str]),
  311. [1, 2],
  312. ),
  313. (
  314. BaseVar(name="dict", type_=Dict[str, str]),
  315. {"name": "dict"},
  316. ),
  317. (
  318. BaseVar(name="dict", type_=Dict[str, str]),
  319. {"set"},
  320. ),
  321. (
  322. BaseVar(name="dict", type_=Dict[str, str]),
  323. (
  324. 1,
  325. 2,
  326. ),
  327. ),
  328. (
  329. BaseVar(name="lst", type_=Dict[str, str]),
  330. BaseVar(name="list_var", type_=List[int]),
  331. ),
  332. (
  333. BaseVar(name="lst", type_=Dict[str, str]),
  334. BaseVar(name="set_var", type_=Set[str]),
  335. ),
  336. (
  337. BaseVar(name="lst", type_=Dict[str, str]),
  338. BaseVar(name="dict_var", type_=Dict[str, str]),
  339. ),
  340. (
  341. BaseVar(name="df", type_=DataFrame),
  342. [1, 2],
  343. ),
  344. (
  345. BaseVar(name="df", type_=DataFrame),
  346. {"name": "dict"},
  347. ),
  348. (
  349. BaseVar(name="df", type_=DataFrame),
  350. {"set"},
  351. ),
  352. (
  353. BaseVar(name="df", type_=DataFrame),
  354. (
  355. 1,
  356. 2,
  357. ),
  358. ),
  359. (
  360. BaseVar(name="df", type_=DataFrame),
  361. BaseVar(name="list_var", type_=List[int]),
  362. ),
  363. (
  364. BaseVar(name="df", type_=DataFrame),
  365. BaseVar(name="set_var", type_=Set[str]),
  366. ),
  367. (
  368. BaseVar(name="df", type_=DataFrame),
  369. BaseVar(name="dict_var", type_=Dict[str, str]),
  370. ),
  371. ],
  372. )
  373. def test_var_unsupported_indexing_dicts(var, index):
  374. """Test unsupported indexing throws a type error.
  375. Args:
  376. var: The base var.
  377. index: The base var index.
  378. """
  379. with pytest.raises(TypeError):
  380. var[index]
  381. @pytest.mark.parametrize(
  382. "fixture,full_name",
  383. [
  384. ("ParentState", "parent_state.var_without_annotation"),
  385. ("ChildState", "parent_state.child_state.var_without_annotation"),
  386. (
  387. "GrandChildState",
  388. "parent_state.child_state.grand_child_state.var_without_annotation",
  389. ),
  390. ("StateWithAnyVar", "state_with_any_var.var_without_annotation"),
  391. ],
  392. )
  393. def test_computed_var_without_annotation_error(request, fixture, full_name):
  394. """Test that a type error is thrown when an attribute of a computed var is
  395. accessed without annotating the computed var.
  396. Args:
  397. request: Fixture Request.
  398. fixture: The state fixture.
  399. full_name: The full name of the state var.
  400. """
  401. with pytest.raises(TypeError) as err:
  402. state = request.getfixturevalue(fixture)
  403. state.var_without_annotation.foo
  404. assert (
  405. err.value.args[0]
  406. == f"You must provide an annotation for the state var `{full_name}`. Annotation cannot be `typing.Any`"
  407. )
  408. @pytest.mark.parametrize(
  409. "fixture,full_name",
  410. [
  411. (
  412. "StateWithCorrectVarAnnotation",
  413. "state_with_correct_var_annotation.var_with_annotation",
  414. ),
  415. (
  416. "StateWithWrongVarAnnotation",
  417. "state_with_wrong_var_annotation.var_with_annotation",
  418. ),
  419. ],
  420. )
  421. def test_computed_var_with_annotation_error(request, fixture, full_name):
  422. """Test that an Attribute error is thrown when a non-existent attribute of an annotated computed var is
  423. accessed or when the wrong annotation is provided to a computed var.
  424. Args:
  425. request: Fixture Request.
  426. fixture: The state fixture.
  427. full_name: The full name of the state var.
  428. """
  429. with pytest.raises(AttributeError) as err:
  430. state = request.getfixturevalue(fixture)
  431. state.var_with_annotation.foo
  432. assert (
  433. err.value.args[0]
  434. == f"The State var `{full_name}` has no attribute 'foo' or may have been annotated wrongly.\n"
  435. f"original message: 'ComputedVar' object has no attribute 'foo'"
  436. )
  437. def test_pickleable_rx_list():
  438. """Test that ReflexList is pickleable."""
  439. rx_list = ReflexList(
  440. original_list=[1, 2, 3], reassign_field=lambda x: x, field_name="random"
  441. )
  442. pickled_list = cloudpickle.dumps(rx_list)
  443. assert cloudpickle.loads(pickled_list) == rx_list
  444. def test_pickleable_rx_dict():
  445. """Test that ReflexDict is pickleable."""
  446. rx_dict = ReflexDict(
  447. original_dict={1: 2, 3: 4}, reassign_field=lambda x: x, field_name="random"
  448. )
  449. pickled_dict = cloudpickle.dumps(rx_dict)
  450. assert cloudpickle.loads(pickled_dict) == rx_dict
  451. def test_pickleable_rx_set():
  452. """Test that ReflexSet is pickleable."""
  453. rx_set = ReflexSet(
  454. original_set={1, 2, 3}, reassign_field=lambda x: x, field_name="random"
  455. )
  456. pickled_set = cloudpickle.dumps(rx_set)
  457. assert cloudpickle.loads(pickled_set) == rx_set
  458. @pytest.mark.parametrize(
  459. "import_var,expected",
  460. zip(
  461. test_import_vars,
  462. [
  463. "DataGrid",
  464. "DataGrid as Grid",
  465. ],
  466. ),
  467. )
  468. def test_import_var(import_var, expected):
  469. """Test that the import var name is computed correctly.
  470. Args:
  471. import_var: The import var.
  472. expected: expected name
  473. """
  474. assert import_var.name == expected
  475. @pytest.mark.parametrize(
  476. "key, expected",
  477. [
  478. ("test_key", BaseVar(name="localStorage.getItem('test_key')", type_=str)),
  479. (
  480. BaseVar(name="key_var", type_=str),
  481. BaseVar(name="localStorage.getItem(key_var)", type_=str),
  482. ),
  483. (
  484. BaseState.val,
  485. BaseVar(name="localStorage.getItem(base_state.val)", type_=str),
  486. ),
  487. (None, BaseVar(name="getAllLocalStorageItems()", type_=Dict)),
  488. ],
  489. )
  490. def test_get_local_storage(key, expected):
  491. """Test that the right BaseVar is return when get_local_storage is called.
  492. Args:
  493. key: Local storage key.
  494. expected: expected BaseVar.
  495. """
  496. local_storage = get_local_storage(key)
  497. assert local_storage.name == expected.name
  498. assert local_storage.type_ == expected.type_
  499. @pytest.mark.parametrize(
  500. "key",
  501. [
  502. ["list", "values"],
  503. {"name": "dict"},
  504. 10,
  505. BaseVar(name="key_var", type_=List),
  506. BaseVar(name="key_var", type_=Dict[str, str]),
  507. ],
  508. )
  509. def test_get_local_storage_raise_error(key):
  510. """Test that a type error is thrown when the wrong key type is provided.
  511. Args:
  512. key: Local storage key.
  513. """
  514. with pytest.raises(TypeError) as err:
  515. get_local_storage(key)
  516. type_ = type(key) if not isinstance(key, Var) else key.type_
  517. assert (
  518. err.value.args[0]
  519. == f"Local storage keys can only be of type `str` or `var` of type `str`. Got `{type_}` instead."
  520. )