test_state_tree.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. """Specialized test for a larger state tree."""
  2. from typing import AsyncGenerator
  3. import pytest
  4. import pytest_asyncio
  5. import reflex as rx
  6. from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key
class Root(BaseState):
    """Root of the state tree.

    Every other state class in this module descends from Root, and the
    ``state_manager_redis`` fixture builds its app with ``rx.App(state=Root)``.
    """

    root: int
class TreeA(Root):
    """TreeA is a child of Root.

    Top of the ``SubA_*`` branch; it declares no computed vars of its own.
    """

    a: int
class SubA_A(TreeA):
    """SubA_A is a child of TreeA.

    Parent of SubA_A_A and SubA_A_B.
    """

    sub_a_a: int
class SubA_A_A(SubA_A):
    """SubA_A_A is a child of SubA_A.

    Parent of SubA_A_A_A, SubA_A_A_B and SubA_A_A_C.
    """

    # Read by the cached computed var defined on the child state SubA_A_A_B.
    sub_a_a_a: int
class SubA_A_A_A(SubA_A_A):
    """SubA_A_A_A is a child of SubA_A_A.

    A leaf state. test_get_state_tree expects fetching it to also pull in its
    sibling SubA_A_A_B, whose cached var reads the shared parent's field.
    """

    sub_a_a_a_a: int
class SubA_A_A_B(SubA_A_A):
    """SubA_A_A_B is a child of SubA_A_A.

    Declares no fields of its own; it only holds a cached computed var that
    reads the parent state's ``sub_a_a_a`` field (a cross-state dependency).
    """

    @rx.var(cache=True)
    def sub_a_a_a_cached(self) -> int:
        """A cached var derived from the parent state's ``sub_a_a_a`` field.

        Returns:
            The value of sub_a_a_a + 1
        """
        # Keep this a plain attribute read: reflex inspects the body to find
        # which vars this computed var depends on.
        return self.sub_a_a_a + 1
class SubA_A_A_C(SubA_A_A):
    """SubA_A_A_C is a child of SubA_A_A.

    A leaf state with a single plain field and no computed vars.
    """

    sub_a_a_a_c: int
class SubA_A_B(SubA_A):
    """SubA_A_B is a child of SubA_A.

    A leaf state with a single plain field and no computed vars.
    """

    sub_a_a_b: int
class SubA_B(TreeA):
    """SubA_B is a child of TreeA.

    A leaf state with a single plain field and no computed vars.
    """

    sub_a_b: int
class TreeB(Root):
    """TreeB is a child of Root.

    Top of the ``SubB_*`` branch: parent of SubB_A, SubB_B and SubB_C.
    """

    b: int
class SubB_A(TreeB):
    """SubB_A is a child of TreeB.

    A leaf state with a single plain field and no computed vars.
    """

    sub_b_a: int
class SubB_B(TreeB):
    """SubB_B is a child of TreeB.

    A leaf state with a single plain field and no computed vars.
    """

    sub_b_b: int
class SubB_C(TreeB):
    """SubB_C is a child of TreeB.

    Parent of SubB_C_A.
    """

    sub_b_c: int
class SubB_C_A(SubB_C):
    """SubB_C_A is a child of SubB_C.

    A leaf state two levels below TreeB, with no computed vars.
    """

    sub_b_c_a: int
class TreeC(Root):
    """TreeC is a child of Root.

    Parent of SubC_A; declares no computed vars.
    """

    c: int
class SubC_A(TreeC):
    """SubC_A is a child of TreeC.

    A leaf state with a single plain field and no computed vars.
    """

    sub_c_a: int
class TreeD(Root):
    """TreeD is a child of Root.

    Its computed var is declared without ``cache=True`` and is listed in
    ALWAYS_COMPUTED_VARS; test_get_state_tree expects TreeD among Root's
    loaded substates for every target.
    """

    d: int

    @rx.var
    def d_var(self) -> int:
        """A computed var declared without ``cache=True``.

        Returns:
            The value of d + 1
        """
        # Keep this a plain attribute read: reflex inspects the body to find
        # which vars this computed var depends on.
        return self.d + 1
class TreeE(Root):
    """TreeE is a child of Root.

    Top of the deepest branch, ``SubE_A`` down to the ``SubE_A_A_A_*``
    leaves, which hold the computed vars listed in ALWAYS_COMPUTED_VARS.
    """

    e: int
class SubE_A(TreeE):
    """SubE_A is a child of TreeE.

    Parent of SubE_A_A; an intermediate link in the ``SubE_*`` chain.
    """

    sub_e_a: int
class SubE_A_A(SubE_A):
    """SubE_A_A is a child of SubE_A.

    Parent of SubE_A_A_A; an intermediate link in the ``SubE_*`` chain.
    """

    sub_e_a_a: int
class SubE_A_A_A(SubE_A_A):
    """SubE_A_A_A is a child of SubE_A_A.

    Parent of SubE_A_A_A_A, SubE_A_A_A_B, SubE_A_A_A_C and SubE_A_A_A_D.
    """

    # Read by the computed vars on the child states SubE_A_A_A_A and
    # SubE_A_A_A_D (cross-state dependencies).
    sub_e_a_a_a: int
class SubE_A_A_A_A(SubE_A_A_A):
    """SubE_A_A_A_A is a child of SubE_A_A_A.

    Holds a computed var declared without ``cache=True`` that is listed in
    ALWAYS_COMPUTED_VARS.
    """

    sub_e_a_a_a_a: int

    @rx.var
    def sub_e_a_a_a_a_var(self) -> int:
        """A computed var declared without ``cache=True``.

        Returns:
            The value of sub_e_a_a_a (the parent state's field) + 1
        """
        # NOTE(review): this reads the parent's ``sub_e_a_a_a``, not this
        # state's own ``sub_e_a_a_a_a``. Changing it alters the cross-state
        # dependency that test_get_state_tree's expectations may rely on —
        # confirm the parent read is intentional before touching it.
        return self.sub_e_a_a_a + 1
class SubE_A_A_A_B(SubE_A_A_A):
    """SubE_A_A_A_B is a child of SubE_A_A_A.

    A plain leaf; test_get_state_tree expects it to be loaded alongside its
    siblings that hold computed vars.
    """

    sub_e_a_a_a_b: int
class SubE_A_A_A_C(SubE_A_A_A):
    """SubE_A_A_A_C is a child of SubE_A_A_A.

    A plain leaf; test_get_state_tree expects it to be loaded alongside its
    siblings that hold computed vars.
    """

    sub_e_a_a_a_c: int
class SubE_A_A_A_D(SubE_A_A_A):
    """SubE_A_A_A_D is a child of SubE_A_A_A.

    Holds a cached computed var that reads the parent state's
    ``sub_e_a_a_a`` field (a cross-state dependency).
    """

    sub_e_a_a_a_d: int

    @rx.var(cache=True)
    def sub_e_a_a_a_d_var(self) -> int:
        """A cached var derived from the parent state's ``sub_e_a_a_a`` field.

        Returns:
            The value of sub_e_a_a_a + 1
        """
        # Keep this a plain attribute read: reflex inspects the body to find
        # which vars this computed var depends on.
        return self.sub_e_a_a_a + 1
  109. ALWAYS_COMPUTED_VARS = {
  110. TreeD.get_full_name(): {"d_var": 1},
  111. SubE_A_A_A_A.get_full_name(): {"sub_e_a_a_a_a_var": 1},
  112. }
  113. ALWAYS_COMPUTED_DICT_KEYS = [
  114. Root.get_full_name(),
  115. TreeD.get_full_name(),
  116. TreeE.get_full_name(),
  117. SubE_A.get_full_name(),
  118. SubE_A_A.get_full_name(),
  119. SubE_A_A_A.get_full_name(),
  120. SubE_A_A_A_A.get_full_name(),
  121. SubE_A_A_A_D.get_full_name(),
  122. ]
  123. @pytest_asyncio.fixture(loop_scope="function", scope="function")
  124. async def state_manager_redis(
  125. app_module_mock,
  126. ) -> AsyncGenerator[StateManager, None]:
  127. """Instance of state manager for redis only.
  128. Args:
  129. app_module_mock: The app module mock fixture.
  130. Yields:
  131. A state manager instance
  132. """
  133. app_module_mock.app = rx.App(state=Root)
  134. state_manager = app_module_mock.app.state_manager
  135. if not isinstance(state_manager, StateManagerRedis):
  136. pytest.skip("Test requires redis")
  137. yield state_manager
  138. await state_manager.close()
  139. @pytest.mark.asyncio
  140. @pytest.mark.parametrize(
  141. ("substate_cls", "exp_root_substates", "exp_root_dict_keys"),
  142. [
  143. (
  144. Root,
  145. [
  146. TreeA.get_name(),
  147. TreeB.get_name(),
  148. TreeC.get_name(),
  149. TreeD.get_name(),
  150. TreeE.get_name(),
  151. ],
  152. [
  153. TreeA.get_full_name(),
  154. SubA_A.get_full_name(),
  155. SubA_A_A.get_full_name(),
  156. SubA_A_A_A.get_full_name(),
  157. SubA_A_A_B.get_full_name(),
  158. SubA_A_A_C.get_full_name(),
  159. SubA_A_B.get_full_name(),
  160. SubA_B.get_full_name(),
  161. TreeB.get_full_name(),
  162. SubB_A.get_full_name(),
  163. SubB_B.get_full_name(),
  164. SubB_C.get_full_name(),
  165. SubB_C_A.get_full_name(),
  166. TreeC.get_full_name(),
  167. SubC_A.get_full_name(),
  168. SubE_A_A_A_B.get_full_name(),
  169. SubE_A_A_A_C.get_full_name(),
  170. *ALWAYS_COMPUTED_DICT_KEYS,
  171. ],
  172. ),
  173. (
  174. TreeA,
  175. (TreeA.get_name(), TreeD.get_name(), TreeE.get_name()),
  176. [
  177. TreeA.get_full_name(),
  178. SubA_A.get_full_name(),
  179. SubA_A_A.get_full_name(),
  180. SubA_A_A_A.get_full_name(),
  181. SubA_A_A_B.get_full_name(),
  182. SubA_A_A_C.get_full_name(),
  183. SubA_A_B.get_full_name(),
  184. SubA_B.get_full_name(),
  185. *ALWAYS_COMPUTED_DICT_KEYS,
  186. ],
  187. ),
  188. (
  189. SubA_A_A_A,
  190. [TreeA.get_name(), TreeD.get_name(), TreeE.get_name()],
  191. [
  192. TreeA.get_full_name(),
  193. SubA_A.get_full_name(),
  194. SubA_A_A.get_full_name(),
  195. SubA_A_A_A.get_full_name(),
  196. SubA_A_A_B.get_full_name(), # Cached var dep
  197. *ALWAYS_COMPUTED_DICT_KEYS,
  198. ],
  199. ),
  200. (
  201. TreeB,
  202. [TreeB.get_name(), TreeD.get_name(), TreeE.get_name()],
  203. [
  204. TreeB.get_full_name(),
  205. SubB_A.get_full_name(),
  206. SubB_B.get_full_name(),
  207. SubB_C.get_full_name(),
  208. SubB_C_A.get_full_name(),
  209. *ALWAYS_COMPUTED_DICT_KEYS,
  210. ],
  211. ),
  212. (
  213. SubB_B,
  214. [TreeB.get_name(), TreeD.get_name(), TreeE.get_name()],
  215. [
  216. TreeB.get_full_name(),
  217. SubB_B.get_full_name(),
  218. *ALWAYS_COMPUTED_DICT_KEYS,
  219. ],
  220. ),
  221. (
  222. SubB_C_A,
  223. [TreeB.get_name(), TreeD.get_name(), TreeE.get_name()],
  224. [
  225. TreeB.get_full_name(),
  226. SubB_C.get_full_name(),
  227. SubB_C_A.get_full_name(),
  228. *ALWAYS_COMPUTED_DICT_KEYS,
  229. ],
  230. ),
  231. (
  232. TreeC,
  233. [TreeC.get_name(), TreeD.get_name(), TreeE.get_name()],
  234. [
  235. TreeC.get_full_name(),
  236. SubC_A.get_full_name(),
  237. *ALWAYS_COMPUTED_DICT_KEYS,
  238. ],
  239. ),
  240. (
  241. TreeD,
  242. [TreeD.get_name(), TreeE.get_name()],
  243. [
  244. *ALWAYS_COMPUTED_DICT_KEYS,
  245. ],
  246. ),
  247. (
  248. TreeE,
  249. [TreeE.get_name(), TreeD.get_name()],
  250. [
  251. # Extra siblings of computed var included now.
  252. SubE_A_A_A_B.get_full_name(),
  253. SubE_A_A_A_C.get_full_name(),
  254. *ALWAYS_COMPUTED_DICT_KEYS,
  255. ],
  256. ),
  257. ],
  258. )
  259. async def test_get_state_tree(
  260. state_manager_redis,
  261. token,
  262. substate_cls,
  263. exp_root_substates,
  264. exp_root_dict_keys,
  265. ):
  266. """Test getting state trees and assert on which branches are retrieved.
  267. Args:
  268. state_manager_redis: The state manager redis fixture.
  269. token: The token fixture.
  270. substate_cls: The substate class to retrieve.
  271. exp_root_substates: The expected substates of the root state.
  272. exp_root_dict_keys: The expected keys of the root state dict.
  273. """
  274. state = await state_manager_redis.get_state(_substate_key(token, substate_cls))
  275. assert isinstance(state, Root)
  276. assert sorted(state.substates) == sorted(exp_root_substates)
  277. # Only computed vars should be returned
  278. assert state.get_delta() == ALWAYS_COMPUTED_VARS
  279. # All of TreeA, TreeD, and TreeE substates should be in the dict
  280. assert sorted(state.dict()) == sorted(exp_root_dict_keys)