test_state.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951
  1. from typing import Dict, List
  2. import pytest
  3. from plotly.graph_objects import Figure
  4. from pynecone.base import Base
  5. from pynecone.constants import IS_HYDRATED, RouteVar
  6. from pynecone.event import Event, EventHandler
  7. from pynecone.state import State
  8. from pynecone.utils import format
  9. from pynecone.var import BaseVar, ComputedVar
class Object(Base):
    """A test object fixture.

    A small ``Base`` subclass with two defaulted props. It is the type of
    ``TestState.obj`` and the value type of ``TestState.complex``, so tests can
    exercise attribute access on nested objects (e.g. ``TestState.obj.prop1``).
    """

    prop1: int = 42
    prop2: str = "hello"
class TestState(State):
    """A test state.

    Root state used by most tests in this file. It declares base vars of
    several kinds (scalars, a list, a dict of lists, a nested ``Object``, a
    dict of ``Object`` and a plotly ``Figure``), two computed vars (``sum`` and
    ``upper``) and one explicit event handler (``do_something``).
    """

    # Set this class as not test one: its name starts with "Test", so without
    # ``__test__ = False`` pytest would try to collect it as a test class.
    __test__ = False

    # Base vars. ``num1`` and ``key`` have no explicit default; the tests
    # expect them to start as 0 and "" respectively.
    num1: int
    num2: float = 3.14
    key: str
    array: List[float] = [1, 2, 3.14]
    mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
    obj: Object = Object()
    complex: Dict[int, Object] = {1: Object(), 2: Object()}
    fig: Figure = Figure()

    # NOTE(review): ComputedVar dependencies appear to be derived from these
    # method bodies, so keep the attribute reads exactly as written.
    @ComputedVar
    def sum(self) -> float:
        """Dynamically sum the numbers.

        Returns:
            The sum of the numbers (``num1 + num2``).
        """
        return self.num1 + self.num2

    @ComputedVar
    def upper(self) -> str:
        """Uppercase the key.

        Returns:
            The uppercased key.
        """
        return self.key.upper()

    def do_something(self):
        """Do something; a no-op event handler used by registration tests."""
        pass
class ChildState(TestState):
    """A child state of ``TestState`` (registered as substate ``child_state``).

    Adds two vars and an event handler that takes two arguments.
    """

    value: str
    count: int = 23

    def change_both(self, value: str, count: int):
        """Change both the value and count.

        Args:
            value: The new value; it is stored uppercased.
            count: The new count; it is stored doubled.
        """
        self.value = value.upper()
        self.count = count * 2
class ChildState2(TestState):
    """A second child state of ``TestState`` (substate ``child_state2``).

    A sibling of ``ChildState`` that declares its own ``value`` var.
    """

    value: str
class GrandchildState(ChildState):
    """A grandchild state: a child of ``ChildState`` (substate ``grandchild_state``)."""

    value2: str

    def do_nothing(self):
        """Do nothing; an event handler used to test handler-name formatting."""
        pass
  64. @pytest.fixture
  65. def test_state() -> TestState:
  66. """A state.
  67. Returns:
  68. A test state.
  69. """
  70. return TestState() # type: ignore
  71. @pytest.fixture
  72. def child_state(test_state) -> ChildState:
  73. """A child state.
  74. Args:
  75. test_state: A test state.
  76. Returns:
  77. A test child state.
  78. """
  79. child_state = test_state.get_substate(["child_state"])
  80. assert child_state is not None
  81. return child_state
  82. @pytest.fixture
  83. def child_state2(test_state) -> ChildState2:
  84. """A second child state.
  85. Args:
  86. test_state: A test state.
  87. Returns:
  88. A second test child state.
  89. """
  90. child_state2 = test_state.get_substate(["child_state2"])
  91. assert child_state2 is not None
  92. return child_state2
  93. @pytest.fixture
  94. def grandchild_state(child_state) -> GrandchildState:
  95. """A state.
  96. Args:
  97. child_state: A child state.
  98. Returns:
  99. A test state.
  100. """
  101. grandchild_state = child_state.get_substate(["grandchild_state"])
  102. assert grandchild_state is not None
  103. return grandchild_state
  104. def test_base_class_vars(test_state):
  105. """Test that the class vars are set correctly.
  106. Args:
  107. test_state: A state.
  108. """
  109. fields = test_state.get_fields()
  110. cls = type(test_state)
  111. for field in fields:
  112. if field in test_state.get_skip_vars():
  113. continue
  114. prop = getattr(cls, field)
  115. assert isinstance(prop, BaseVar)
  116. assert prop.name == field
  117. assert cls.num1.type_ == int
  118. assert cls.num2.type_ == float
  119. assert cls.key.type_ == str
  120. def test_computed_class_var(test_state):
  121. """Test that the class computed vars are set correctly.
  122. Args:
  123. test_state: A state.
  124. """
  125. cls = type(test_state)
  126. vars = [(prop.name, prop.type_) for prop in cls.computed_vars.values()]
  127. assert ("sum", float) in vars
  128. assert ("upper", str) in vars
  129. def test_class_vars(test_state):
  130. """Test that the class vars are set correctly.
  131. Args:
  132. test_state: A state.
  133. """
  134. cls = type(test_state)
  135. assert set(cls.vars.keys()) == {
  136. IS_HYDRATED, # added by hydrate_middleware to all State
  137. "num1",
  138. "num2",
  139. "key",
  140. "array",
  141. "mapping",
  142. "obj",
  143. "complex",
  144. "sum",
  145. "upper",
  146. "fig",
  147. }
  148. def test_event_handlers(test_state):
  149. """Test that event handler is set correctly.
  150. Args:
  151. test_state: A state.
  152. """
  153. expected = {
  154. "do_something",
  155. "set_array",
  156. "set_complex",
  157. "set_fig",
  158. "set_key",
  159. "set_mapping",
  160. "set_num1",
  161. "set_num2",
  162. "set_obj",
  163. }
  164. cls = type(test_state)
  165. assert set(cls.event_handlers.keys()).intersection(expected) == expected
  166. def test_default_value(test_state):
  167. """Test that the default value of a var is correct.
  168. Args:
  169. test_state: A state.
  170. """
  171. assert test_state.num1 == 0
  172. assert test_state.num2 == 3.14
  173. assert test_state.key == ""
  174. assert test_state.sum == 3.14
  175. assert test_state.upper == ""
  176. def test_computed_vars(test_state):
  177. """Test that the computed var is computed correctly.
  178. Args:
  179. test_state: A state.
  180. """
  181. test_state.num1 = 1
  182. test_state.num2 = 4
  183. assert test_state.sum == 5
  184. test_state.key = "hello world"
  185. assert test_state.upper == "HELLO WORLD"
  186. def test_dict(test_state):
  187. """Test that the dict representation of a state is correct.
  188. Args:
  189. test_state: A state.
  190. """
  191. substates = {"child_state", "child_state2"}
  192. assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates
  193. assert (
  194. set(test_state.dict(include_computed=False).keys())
  195. == set(test_state.base_vars) | substates
  196. )
  197. def test_default_setters(test_state):
  198. """Test that we can set default values.
  199. Args:
  200. test_state: A state.
  201. """
  202. for prop_name in test_state.base_vars:
  203. # Each base var should have a default setter.
  204. assert hasattr(test_state, f"set_{prop_name}")
  205. def test_class_indexing_with_vars():
  206. """Test that we can index into a state var with another var."""
  207. prop = TestState.array[TestState.num1]
  208. assert str(prop) == "{test_state.array.at(test_state.num1)}"
  209. prop = TestState.mapping["a"][TestState.num1]
  210. assert str(prop) == '{test_state.mapping["a"].at(test_state.num1)}'
  211. def test_class_attributes():
  212. """Test that we can get class attributes."""
  213. prop = TestState.obj.prop1
  214. assert str(prop) == "{test_state.obj.prop1}"
  215. prop = TestState.complex[1].prop1
  216. assert str(prop) == "{test_state.complex[1].prop1}"
  217. def test_get_parent_state():
  218. """Test getting the parent state."""
  219. assert TestState.get_parent_state() is None
  220. assert ChildState.get_parent_state() == TestState
  221. assert ChildState2.get_parent_state() == TestState
  222. assert GrandchildState.get_parent_state() == ChildState
  223. def test_get_substates():
  224. """Test getting the substates."""
  225. assert TestState.get_substates() == {ChildState, ChildState2}
  226. assert ChildState.get_substates() == {GrandchildState}
  227. assert ChildState2.get_substates() == set()
  228. assert GrandchildState.get_substates() == set()
  229. def test_get_name():
  230. """Test getting the name of a state."""
  231. assert TestState.get_name() == "test_state"
  232. assert ChildState.get_name() == "child_state"
  233. assert ChildState2.get_name() == "child_state2"
  234. assert GrandchildState.get_name() == "grandchild_state"
  235. def test_get_full_name():
  236. """Test getting the full name."""
  237. assert TestState.get_full_name() == "test_state"
  238. assert ChildState.get_full_name() == "test_state.child_state"
  239. assert ChildState2.get_full_name() == "test_state.child_state2"
  240. assert GrandchildState.get_full_name() == "test_state.child_state.grandchild_state"
  241. def test_get_class_substate():
  242. """Test getting the substate of a class."""
  243. assert TestState.get_class_substate(("child_state",)) == ChildState
  244. assert TestState.get_class_substate(("child_state2",)) == ChildState2
  245. assert ChildState.get_class_substate(("grandchild_state",)) == GrandchildState
  246. assert (
  247. TestState.get_class_substate(("child_state", "grandchild_state"))
  248. == GrandchildState
  249. )
  250. with pytest.raises(ValueError):
  251. TestState.get_class_substate(("invalid_child",))
  252. with pytest.raises(ValueError):
  253. TestState.get_class_substate(
  254. (
  255. "child_state",
  256. "invalid_child",
  257. )
  258. )
  259. def test_get_class_var():
  260. """Test getting the var of a class."""
  261. assert TestState.get_class_var(("num1",)) == TestState.num1
  262. assert TestState.get_class_var(("num2",)) == TestState.num2
  263. assert ChildState.get_class_var(("value",)) == ChildState.value
  264. assert GrandchildState.get_class_var(("value2",)) == GrandchildState.value2
  265. assert TestState.get_class_var(("child_state", "value")) == ChildState.value
  266. assert (
  267. TestState.get_class_var(("child_state", "grandchild_state", "value2"))
  268. == GrandchildState.value2
  269. )
  270. assert (
  271. ChildState.get_class_var(("grandchild_state", "value2"))
  272. == GrandchildState.value2
  273. )
  274. with pytest.raises(ValueError):
  275. TestState.get_class_var(("invalid_var",))
  276. with pytest.raises(ValueError):
  277. TestState.get_class_var(
  278. (
  279. "child_state",
  280. "invalid_var",
  281. )
  282. )
  283. def test_set_class_var():
  284. """Test setting the var of a class."""
  285. with pytest.raises(AttributeError):
  286. TestState.num3 # type: ignore
  287. TestState._set_var(BaseVar(name="num3", type_=int).set_state(TestState))
  288. var = TestState.num3 # type: ignore
  289. assert var.name == "num3"
  290. assert var.type_ == int
  291. assert var.state == TestState.get_full_name()
  292. def test_set_parent_and_substates(test_state, child_state, grandchild_state):
  293. """Test setting the parent and substates.
  294. Args:
  295. test_state: A state.
  296. child_state: A child state.
  297. grandchild_state: A grandchild state.
  298. """
  299. assert len(test_state.substates) == 2
  300. assert set(test_state.substates) == {"child_state", "child_state2"}
  301. assert child_state.parent_state == test_state
  302. assert len(child_state.substates) == 1
  303. assert set(child_state.substates) == {"grandchild_state"}
  304. assert grandchild_state.parent_state == child_state
  305. assert len(grandchild_state.substates) == 0
  306. def test_get_child_attribute(test_state, child_state, child_state2, grandchild_state):
  307. """Test getting the attribute of a state.
  308. Args:
  309. test_state: A state.
  310. child_state: A child state.
  311. child_state2: A child state.
  312. grandchild_state: A grandchild state.
  313. """
  314. assert test_state.num1 == 0
  315. assert child_state.value == ""
  316. assert child_state2.value == ""
  317. assert child_state.count == 23
  318. assert grandchild_state.value2 == ""
  319. with pytest.raises(AttributeError):
  320. test_state.invalid
  321. with pytest.raises(AttributeError):
  322. test_state.child_state.invalid
  323. with pytest.raises(AttributeError):
  324. test_state.child_state.grandchild_state.invalid
  325. def test_set_child_attribute(test_state, child_state, grandchild_state):
  326. """Test setting the attribute of a state.
  327. Args:
  328. test_state: A state.
  329. child_state: A child state.
  330. grandchild_state: A grandchild state.
  331. """
  332. test_state.num1 = 10
  333. assert test_state.num1 == 10
  334. assert child_state.num1 == 10
  335. assert grandchild_state.num1 == 10
  336. grandchild_state.num1 = 5
  337. assert test_state.num1 == 5
  338. assert child_state.num1 == 5
  339. assert grandchild_state.num1 == 5
  340. child_state.value = "test"
  341. assert child_state.value == "test"
  342. assert grandchild_state.value == "test"
  343. grandchild_state.value = "test2"
  344. assert child_state.value == "test2"
  345. assert grandchild_state.value == "test2"
  346. grandchild_state.value2 = "test3"
  347. assert grandchild_state.value2 == "test3"
  348. def test_get_substate(test_state, child_state, child_state2, grandchild_state):
  349. """Test getting the substate of a state.
  350. Args:
  351. test_state: A state.
  352. child_state: A child state.
  353. child_state2: A child state.
  354. grandchild_state: A grandchild state.
  355. """
  356. assert test_state.get_substate(("child_state",)) == child_state
  357. assert test_state.get_substate(("child_state2",)) == child_state2
  358. assert (
  359. test_state.get_substate(("child_state", "grandchild_state")) == grandchild_state
  360. )
  361. assert child_state.get_substate(("grandchild_state",)) == grandchild_state
  362. with pytest.raises(ValueError):
  363. test_state.get_substate(("invalid",))
  364. with pytest.raises(ValueError):
  365. test_state.get_substate(("child_state", "invalid"))
  366. with pytest.raises(ValueError):
  367. test_state.get_substate(("child_state", "grandchild_state", "invalid"))
  368. def test_set_dirty_var(test_state):
  369. """Test changing state vars marks the value as dirty.
  370. Args:
  371. test_state: A state.
  372. """
  373. # Initially there should be no dirty vars.
  374. assert test_state.dirty_vars == set()
  375. # Setting a var should mark it as dirty.
  376. test_state.num1 = 1
  377. # assert test_state.dirty_vars == {"num1", "sum"}
  378. assert test_state.dirty_vars == {"num1"}
  379. # Setting another var should mark it as dirty.
  380. test_state.num2 = 2
  381. # assert test_state.dirty_vars == {"num1", "num2", "sum"}
  382. assert test_state.dirty_vars == {"num1", "num2"}
  383. # Cleaning the state should remove all dirty vars.
  384. test_state.clean()
  385. assert test_state.dirty_vars == set()
  386. def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_state):
  387. """Test changing substate vars marks the value as dirty.
  388. Args:
  389. test_state: A state.
  390. child_state: A child state.
  391. child_state2: A child state.
  392. grandchild_state: A grandchild state.
  393. """
  394. # Initially there should be no dirty vars.
  395. assert test_state.dirty_vars == set()
  396. assert child_state.dirty_vars == set()
  397. assert child_state2.dirty_vars == set()
  398. assert grandchild_state.dirty_vars == set()
  399. # Setting a var should mark it as dirty.
  400. child_state.value = "test"
  401. assert child_state.dirty_vars == {"value"}
  402. assert test_state.dirty_substates == {"child_state"}
  403. assert child_state.dirty_substates == set()
  404. # Cleaning the parent state should remove the dirty substate.
  405. test_state.clean()
  406. assert test_state.dirty_substates == set()
  407. assert child_state.dirty_vars == set()
  408. # Setting a var on the grandchild should bubble up.
  409. grandchild_state.value2 = "test2"
  410. assert child_state.dirty_substates == {"grandchild_state"}
  411. assert test_state.dirty_substates == {"child_state"}
  412. # Cleaning the middle state should keep the parent state dirty.
  413. child_state.clean()
  414. assert test_state.dirty_substates == {"child_state"}
  415. assert child_state.dirty_substates == set()
  416. assert grandchild_state.dirty_vars == set()
  417. def test_reset(test_state, child_state):
  418. """Test resetting the state.
  419. Args:
  420. test_state: A state.
  421. child_state: A child state.
  422. """
  423. # Set some values.
  424. test_state.num1 = 1
  425. test_state.num2 = 2
  426. child_state.value = "test"
  427. # Reset the state.
  428. test_state.reset()
  429. # The values should be reset.
  430. assert test_state.num1 == 0
  431. assert test_state.num2 == 3.14
  432. assert child_state.value == ""
  433. # The dirty vars should be reset.
  434. assert test_state.dirty_vars == set()
  435. assert child_state.dirty_vars == set()
  436. # The dirty substates should be reset.
  437. assert test_state.dirty_substates == set()
  438. @pytest.mark.asyncio
  439. async def test_process_event_simple(test_state):
  440. """Test processing an event.
  441. Args:
  442. test_state: A state.
  443. """
  444. assert test_state.num1 == 0
  445. event = Event(token="t", name="set_num1", payload={"value": 69})
  446. update = await test_state._process(event)
  447. # The event should update the value.
  448. assert test_state.num1 == 69
  449. # The delta should contain the changes, including computed vars.
  450. # assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
  451. assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
  452. assert update.events == []
  453. @pytest.mark.asyncio
  454. async def test_process_event_substate(test_state, child_state, grandchild_state):
  455. """Test processing an event on a substate.
  456. Args:
  457. test_state: A state.
  458. child_state: A child state.
  459. grandchild_state: A grandchild state.
  460. """
  461. # Events should bubble down to the substate.
  462. assert child_state.value == ""
  463. assert child_state.count == 23
  464. event = Event(
  465. token="t", name="child_state.change_both", payload={"value": "hi", "count": 12}
  466. )
  467. update = await test_state._process(event)
  468. assert child_state.value == "HI"
  469. assert child_state.count == 24
  470. assert update.delta == {
  471. "test_state": {"sum": 3.14, "upper": ""},
  472. "test_state.child_state": {"value": "HI", "count": 24},
  473. }
  474. test_state.clean()
  475. # Test with the granchild state.
  476. assert grandchild_state.value2 == ""
  477. event = Event(
  478. token="t",
  479. name="child_state.grandchild_state.set_value2",
  480. payload={"value": "new"},
  481. )
  482. update = await test_state._process(event)
  483. assert grandchild_state.value2 == "new"
  484. assert update.delta == {
  485. "test_state": {"sum": 3.14, "upper": ""},
  486. "test_state.child_state.grandchild_state": {"value2": "new"},
  487. }
  488. def test_format_event_handler():
  489. """Test formatting an event handler."""
  490. assert (
  491. format.format_event_handler(TestState.do_something) == "test_state.do_something" # type: ignore
  492. )
  493. assert (
  494. format.format_event_handler(ChildState.change_both) # type: ignore
  495. == "test_state.child_state.change_both"
  496. )
  497. assert (
  498. format.format_event_handler(GrandchildState.do_nothing) # type: ignore
  499. == "test_state.child_state.grandchild_state.do_nothing"
  500. )
  501. def test_get_token(test_state):
  502. assert test_state.get_token() == ""
  503. token = "b181904c-3953-4a79-dc18-ae9518c22f05"
  504. test_state.router_data = {RouteVar.CLIENT_TOKEN: token}
  505. assert test_state.get_token() == token
  506. def test_get_sid(test_state):
  507. """Test getting session id.
  508. Args:
  509. test_state: A state.
  510. """
  511. assert test_state.get_sid() == ""
  512. sid = "9fpxSzPb9aFMb4wFAAAH"
  513. test_state.router_data = {RouteVar.SESSION_ID: sid}
  514. assert test_state.get_sid() == sid
  515. def test_get_headers(test_state):
  516. """Test getting client headers.
  517. Args:
  518. test_state: A state.
  519. """
  520. assert test_state.get_headers() == {}
  521. headers = {"host": "localhost:8000", "connection": "keep-alive"}
  522. test_state.router_data = {RouteVar.HEADERS: headers}
  523. assert test_state.get_headers() == headers
  524. def test_get_client_ip(test_state):
  525. """Test getting client IP.
  526. Args:
  527. test_state: A state.
  528. """
  529. assert test_state.get_client_ip() == ""
  530. client_ip = "127.0.0.1"
  531. test_state.router_data = {RouteVar.CLIENT_IP: client_ip}
  532. assert test_state.get_client_ip() == client_ip
  533. def test_get_current_page(test_state):
  534. assert test_state.get_current_page() == ""
  535. route = "mypage/subpage"
  536. test_state.router_data = {RouteVar.PATH: route}
  537. assert test_state.get_current_page() == route
  538. def test_get_query_params(test_state):
  539. assert test_state.get_query_params() == {}
  540. params = {"p1": "a", "p2": "b"}
  541. test_state.router_data = {RouteVar.QUERY: params}
  542. assert test_state.get_query_params() == params
  543. def test_add_var(test_state):
  544. test_state.add_var("dynamic_int", int, 42)
  545. assert test_state.dynamic_int == 42
  546. test_state.add_var("dynamic_list", List[int], [5, 10])
  547. assert test_state.dynamic_list == [5, 10]
  548. assert test_state.dynamic_list == [5, 10]
  549. # how to test that one?
  550. # test_state.dynamic_list.append(15)
  551. # assert test_state.dynamic_list == [5, 10, 15]
  552. test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
  553. assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
  554. assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
  555. def test_add_var_default_handlers(test_state):
  556. test_state.add_var("rand_int", int, 10)
  557. assert "set_rand_int" in test_state.event_handlers
  558. assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler)
class InterdependentState(State):
    """A state with 3 vars and 3 computed vars.

    x: a variable that no computed var depends on
    v1: a variable that one computed var directly depends on
    _v2: a backend variable that one computed var directly depends on
    v1x2: a computed var that depends on v1
    v2x2: a computed var that depends on backend var _v2
    v1x2x2: a computed var that depends on computed var v1x2
    """

    x: int = 0
    v1: int = 0
    _v2: int = 1

    # NOTE(review): ComputedVar dependencies appear to be derived from these
    # method bodies, so keep the attribute reads exactly as written.
    @ComputedVar
    def v1x2(self) -> int:
        """Depends on var v1.

        Returns:
            Var v1 multiplied by 2.
        """
        return self.v1 * 2

    @ComputedVar
    def v2x2(self) -> int:
        """Depends on backend var _v2.

        Returns:
            Backend var _v2 multiplied by 2.
        """
        return self._v2 * 2

    @ComputedVar
    def v1x2x2(self) -> int:
        """Depends on ComputedVar v1x2 (and so, indirectly, on v1).

        Returns:
            ComputedVar v1x2 multiplied by 2.
        """
        return self.v1x2 * 2
  592. @pytest.fixture
  593. def interdependent_state() -> State:
  594. """A state with varying dependency between vars.
  595. Returns:
  596. instance of InterdependentState
  597. """
  598. s = InterdependentState()
  599. s.dict() # prime initial relationships by accessing all ComputedVars
  600. return s
  601. # def test_not_dirty_computed_var_from_var(interdependent_state):
  602. # """Set Var that no ComputedVar depends on, expect no recalculation.
  603. # Args:
  604. # interdependent_state: A state with varying Var dependencies.
  605. # """
  606. # interdependent_state.x = 5
  607. # assert interdependent_state.get_delta() == {
  608. # interdependent_state.get_full_name(): {"x": 5},
  609. # }
  610. # def test_dirty_computed_var_from_var(interdependent_state):
  611. # """Set Var that ComputedVar depends on, expect recalculation.
  612. # The other ComputedVar depends on the changed ComputedVar and should also be
  613. # recalculated. No other ComputedVars should be recalculated.
  614. # Args:
  615. # interdependent_state: A state with varying Var dependencies.
  616. # """
  617. # interdependent_state.v1 = 1
  618. # assert interdependent_state.get_delta() == {
  619. # interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
  620. # }
  621. # def test_dirty_computed_var_from_backend_var(interdependent_state):
  622. # """Set backend var that ComputedVar depends on, expect recalculation.
  623. # Args:
  624. # interdependent_state: A state with varying Var dependencies.
  625. # """
  626. # interdependent_state._v2 = 2
  627. # assert interdependent_state.get_delta() == {
  628. # interdependent_state.get_full_name(): {"v2x2": 4},
  629. # }
  630. def test_per_state_backend_var(interdependent_state):
  631. """Set backend var on one instance, expect no affect in other instances.
  632. Args:
  633. interdependent_state: A state with varying Var dependencies.
  634. """
  635. s2 = InterdependentState()
  636. assert s2._v2 == interdependent_state._v2
  637. interdependent_state._v2 = 2
  638. assert s2._v2 != interdependent_state._v2
  639. s3 = InterdependentState()
  640. assert s3._v2 != interdependent_state._v2
  641. # both s2 and s3 should still have the default value
  642. assert s2._v2 == s3._v2
  643. # changing s2._v2 should not affect others
  644. s2._v2 = 4
  645. assert s2._v2 != interdependent_state._v2
  646. assert s2._v2 != s3._v2
  647. def test_child_state():
  648. """Test that the child state computed vars can reference parent state vars."""
  649. class MainState(State):
  650. v: int = 2
  651. class ChildState(MainState):
  652. @ComputedVar
  653. def rendered_var(self):
  654. return self.v
  655. ms = MainState()
  656. cs = ms.substates[ChildState.get_name()]
  657. assert ms.v == 2
  658. assert cs.v == 2
  659. assert cs.rendered_var == 2
  660. def test_conditional_computed_vars():
  661. """Test that computed vars can have conditionals."""
  662. class MainState(State):
  663. flag: bool = False
  664. t1: str = "a"
  665. t2: str = "b"
  666. @ComputedVar
  667. def rendered_var(self) -> str:
  668. if self.flag:
  669. return self.t1
  670. return self.t2
  671. ms = MainState()
  672. # Initially there are no dirty computed vars.
  673. assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
  674. assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
  675. assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
  676. assert ms.computed_vars["rendered_var"].deps(objclass=MainState) == {
  677. "flag",
  678. "t1",
  679. "t2",
  680. }
  681. def test_event_handlers_convert_to_fns(test_state, child_state):
  682. """Test that when the state is initialized, event handlers are converted to fns.
  683. Args:
  684. test_state: A state with event handlers.
  685. child_state: A child state with event handlers.
  686. """
  687. # The class instances should be event handlers.
  688. assert isinstance(TestState.do_something, EventHandler)
  689. assert isinstance(ChildState.change_both, EventHandler)
  690. # The object instances should be fns.
  691. test_state.do_something()
  692. child_state.change_both(value="goose", count=9)
  693. assert child_state.value == "GOOSE"
  694. assert child_state.count == 18
  695. def test_event_handlers_call_other_handlers():
  696. """Test that event handlers can call other event handlers."""
  697. class MainState(State):
  698. v: int = 0
  699. def set_v(self, v: int):
  700. self.v = v
  701. def set_v2(self, v: int):
  702. self.set_v(v)
  703. ms = MainState()
  704. ms.set_v2(1)
  705. assert ms.v == 1
  706. def test_computed_var_cached():
  707. """Test that a ComputedVar doesn't recalculate when accessed."""
  708. comp_v_calls = 0
  709. class ComputedState(State):
  710. v: int = 0
  711. @ComputedVar
  712. def comp_v(self) -> int:
  713. nonlocal comp_v_calls
  714. comp_v_calls += 1
  715. return self.v
  716. cs = ComputedState()
  717. assert cs.dict()["v"] == 0
  718. assert comp_v_calls == 1
  719. assert cs.dict()["comp_v"] == 0
  720. assert comp_v_calls == 1
  721. assert cs.comp_v == 0
  722. assert comp_v_calls == 1
  723. cs.v = 1
  724. assert comp_v_calls == 1
  725. assert cs.comp_v == 1
  726. assert comp_v_calls == 2