test_component_state.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. """Test that per-component state scaffold works and operates independently."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.state import State, _substate_key
  6. from reflex.testing import AppHarness
  7. from . import utils
def ComponentStateApp():
    """App using per component state.

    Defines two ways of giving each component instance its own state:

    * ``MultiCounter`` -- an ``rx.ComponentState`` subclass; two
      ``MultiCounter.create(...)`` calls yield components with different
      ``.State`` (asserted in ``index``).
    * ``multi_counter_func`` -- a factory that declares a new local
      ``rx.State`` subclass on every call.

    The ``index`` page renders two instances of each style. It also renders a
    page-level button bound to counter A's event and two text elements that
    display the state names of counters A and B.

    NOTE(review): this function is passed as ``app_source`` to
    ``AppHarness.create``. The imports live inside the body, presumably so
    the harness can use the body as a standalone app module -- confirm.
    """
    from typing import Generic, TypeVar

    import reflex as rx

    # Type variable that makes MultiCounter generic over its ``_be`` backend var.
    E = TypeVar("E")

    class MultiCounter(rx.ComponentState, Generic[E]):
        """ComponentState style.

        ``count`` is rendered on the page. The underscore-prefixed fields are
        backend vars, which the test inspects through ``_backend_vars``.
        """

        count: int = 0
        # Generic (TypeVar-annotated) backend var with no default; the test
        # expects it to start out as None.
        _be: E
        # Backend var with no explicit default; the test expects 0.
        _be_int: int
        # Backend var with an explicit default.
        _be_str: str = "42"

        @rx.event
        def increment(self):
            self.count += 1
            # Mirror the count into the backend var so the test can verify that
            # backend state is tracked per instance.
            self._be = self.count  # pyright: ignore [reportAttributeAccessIssue]

        @classmethod
        def get_component(cls, *children, **props):
            # Element ids are suffixed with props["id"] (or "default") so the
            # test can locate each instance; props are also forwarded to vstack.
            return rx.vstack(
                *children,
                rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"),
                rx.button(
                    "Increment",
                    on_click=cls.increment,
                    id=f"button-{props.get('id', 'default')}",
                ),
                **props,
            )

    def multi_counter_func(id: str = "default") -> rx.Component:
        """Local-substate style.

        Args:
            id: identifier for this instance; used as the suffix of the
                ``count-<id>`` and ``button-<id>`` element ids.

        Returns:
            A new instance of the component with its own state.
        """

        # Declared inside the function, so every call creates a separate class.
        class _Counter(rx.State):
            count: int = 0

            @rx.event
            def increment(self):
                self.count += 1

        # State=_Counter: presumably exposed as ``.State`` on the returned
        # component (``index`` compares ``mc_c.State`` and ``mc_d.State``).
        return rx.vstack(
            rx.heading(_Counter.count, id=f"count-{id}"),
            rx.button(
                "Increment",
                on_click=_Counter.increment,
                id=f"button-{id}",
            ),
            State=_Counter,
        )

    # Unused in this body (hence noqa F841); presumably looked up by name by
    # the test harness -- confirm.
    app = rx.App(_state=rx.State)  # noqa: F841

    @rx.page()
    def index():
        """Render two counters of each state style plus test helper elements."""
        mc_a = MultiCounter.create(id="a")
        mc_b = MultiCounter.create(id="b")
        mc_c = multi_counter_func(id="c")
        mc_d = multi_counter_func(id="d")
        # Each instance must get its own state.
        assert mc_a.State != mc_b.State
        assert mc_c.State != mc_d.State
        return rx.vstack(
            mc_a,
            mc_b,
            mc_c,
            mc_d,
            # Triggers counter A's event from outside its component.
            rx.button(
                "Inc A",
                on_click=mc_a.State.increment,  # pyright: ignore [reportAttributeAccessIssue, reportOptionalMemberAccess]
                id="inc-a",
            ),
            # Expose the state names so the test can look up the substates.
            rx.text(
                mc_a.State.get_name() if mc_a.State is not None else "",
                id="a_state_name",
            ),
            rx.text(
                mc_b.State.get_name() if mc_b.State is not None else "",
                id="b_state_name",
            ),
        )
  84. @pytest.fixture()
  85. def component_state_app(tmp_path) -> Generator[AppHarness, None, None]:
  86. """Start ComponentStateApp app at tmp_path via AppHarness.
  87. Args:
  88. tmp_path: pytest tmp_path fixture
  89. Yields:
  90. running AppHarness instance
  91. """
  92. with AppHarness.create(
  93. root=tmp_path,
  94. app_source=ComponentStateApp,
  95. ) as harness:
  96. yield harness
  97. @pytest.mark.asyncio
  98. async def test_component_state_app(component_state_app: AppHarness):
  99. """Increment counters independently.
  100. Args:
  101. component_state_app: harness for ComponentStateApp app
  102. """
  103. assert component_state_app.app_instance is not None, "app is not running"
  104. driver = component_state_app.frontend()
  105. ss = utils.SessionStorage(driver)
  106. assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
  107. root_state_token = _substate_key(ss.get("token"), State)
  108. count_a = driver.find_element(By.ID, "count-a")
  109. count_b = driver.find_element(By.ID, "count-b")
  110. button_a = driver.find_element(By.ID, "button-a")
  111. button_b = driver.find_element(By.ID, "button-b")
  112. button_inc_a = driver.find_element(By.ID, "inc-a")
  113. # Check that backend vars in mixins are okay
  114. a_state_name = driver.find_element(By.ID, "a_state_name").text
  115. b_state_name = driver.find_element(By.ID, "b_state_name").text
  116. root_state = await component_state_app.get_state(root_state_token)
  117. a_state = root_state.substates[a_state_name]
  118. b_state = root_state.substates[b_state_name]
  119. assert a_state._backend_vars == a_state.backend_vars
  120. assert a_state._backend_vars == b_state._backend_vars
  121. assert a_state._backend_vars["_be"] is None
  122. assert a_state._backend_vars["_be_int"] == 0
  123. assert a_state._backend_vars["_be_str"] == "42"
  124. assert count_a.text == "0"
  125. button_a.click()
  126. assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1"
  127. button_a.click()
  128. assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2"
  129. button_inc_a.click()
  130. assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
  131. root_state = await component_state_app.get_state(root_state_token)
  132. a_state = root_state.substates[a_state_name]
  133. b_state = root_state.substates[b_state_name]
  134. assert a_state._backend_vars != a_state.backend_vars
  135. assert a_state._be == a_state._backend_vars["_be"] == 3
  136. assert b_state._be is None
  137. assert b_state._backend_vars["_be"] is None
  138. assert count_b.text == "0"
  139. button_b.click()
  140. assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"
  141. button_b.click()
  142. assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"
  143. root_state = await component_state_app.get_state(root_state_token)
  144. a_state = root_state.substates[a_state_name]
  145. b_state = root_state.substates[b_state_name]
  146. assert b_state._backend_vars != b_state.backend_vars
  147. assert b_state._be == b_state._backend_vars["_be"] == 2
  148. # Check locally-defined substate style
  149. count_c = driver.find_element(By.ID, "count-c")
  150. count_d = driver.find_element(By.ID, "count-d")
  151. button_c = driver.find_element(By.ID, "button-c")
  152. button_d = driver.find_element(By.ID, "button-d")
  153. assert component_state_app.poll_for_content(count_c, exp_not_equal="") == "0"
  154. assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
  155. button_c.click()
  156. assert component_state_app.poll_for_content(count_c, exp_not_equal="0") == "1"
  157. assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
  158. button_c.click()
  159. assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
  160. assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
  161. button_d.click()
  162. assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
  163. assert component_state_app.poll_for_content(count_d, exp_not_equal="0") == "1"