test_component_state.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. """Test that per-component state scaffold works and operates independently."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.state import State, _substate_key
  6. from reflex.testing import AppHarness
  7. from . import utils
  8. def ComponentStateApp():
  9. """App using per component state."""
  10. from typing import Generic, TypeVar
  11. import reflex as rx
  12. E = TypeVar("E")
  13. class MultiCounter(rx.ComponentState, Generic[E]):
  14. """ComponentState style."""
  15. count: int = 0
  16. _be: E
  17. _be_int: int
  18. _be_str: str = "42"
  19. def increment(self):
  20. self.count += 1
  21. self._be = self.count # type: ignore
  22. @classmethod
  23. def get_component(cls, *children, **props):
  24. return rx.vstack(
  25. *children,
  26. rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"),
  27. rx.button(
  28. "Increment",
  29. on_click=cls.increment,
  30. id=f"button-{props.get('id', 'default')}",
  31. ),
  32. **props,
  33. )
  34. def multi_counter_func(id: str = "default") -> rx.Component:
  35. """Local-substate style.
  36. Args:
  37. id: identifier for this instance
  38. Returns:
  39. A new instance of the component with its own state.
  40. """
  41. class _Counter(rx.State):
  42. count: int = 0
  43. def increment(self):
  44. self.count += 1
  45. return rx.vstack(
  46. rx.heading(_Counter.count, id=f"count-{id}"),
  47. rx.button(
  48. "Increment",
  49. on_click=_Counter.increment,
  50. id=f"button-{id}",
  51. ),
  52. State=_Counter,
  53. )
  54. app = rx.App(state=rx.State) # noqa
  55. @rx.page()
  56. def index():
  57. mc_a = MultiCounter.create(id="a")
  58. mc_b = MultiCounter.create(id="b")
  59. mc_c = multi_counter_func(id="c")
  60. mc_d = multi_counter_func(id="d")
  61. assert mc_a.State != mc_b.State
  62. assert mc_c.State != mc_d.State
  63. return rx.vstack(
  64. mc_a,
  65. mc_b,
  66. mc_c,
  67. mc_d,
  68. rx.button(
  69. "Inc A",
  70. on_click=mc_a.State.increment, # type: ignore
  71. id="inc-a",
  72. ),
  73. rx.text(
  74. mc_a.State.get_name() if mc_a.State is not None else "",
  75. id="a_state_name",
  76. ),
  77. rx.text(
  78. mc_b.State.get_name() if mc_b.State is not None else "",
  79. id="b_state_name",
  80. ),
  81. )
  82. @pytest.fixture()
  83. def component_state_app(tmp_path) -> Generator[AppHarness, None, None]:
  84. """Start ComponentStateApp app at tmp_path via AppHarness.
  85. Args:
  86. tmp_path: pytest tmp_path fixture
  87. Yields:
  88. running AppHarness instance
  89. """
  90. with AppHarness.create(
  91. root=tmp_path,
  92. app_source=ComponentStateApp, # type: ignore
  93. ) as harness:
  94. yield harness
  95. @pytest.mark.asyncio
  96. async def test_component_state_app(component_state_app: AppHarness):
  97. """Increment counters independently.
  98. Args:
  99. component_state_app: harness for ComponentStateApp app
  100. """
  101. assert component_state_app.app_instance is not None, "app is not running"
  102. driver = component_state_app.frontend()
  103. ss = utils.SessionStorage(driver)
  104. assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
  105. root_state_token = _substate_key(ss.get("token"), State)
  106. count_a = driver.find_element(By.ID, "count-a")
  107. count_b = driver.find_element(By.ID, "count-b")
  108. button_a = driver.find_element(By.ID, "button-a")
  109. button_b = driver.find_element(By.ID, "button-b")
  110. button_inc_a = driver.find_element(By.ID, "inc-a")
  111. # Check that backend vars in mixins are okay
  112. a_state_name = driver.find_element(By.ID, "a_state_name").text
  113. b_state_name = driver.find_element(By.ID, "b_state_name").text
  114. root_state = await component_state_app.get_state(root_state_token)
  115. a_state = root_state.substates[a_state_name]
  116. b_state = root_state.substates[b_state_name]
  117. assert a_state._backend_vars == a_state.backend_vars
  118. assert a_state._backend_vars == b_state._backend_vars
  119. assert a_state._backend_vars["_be"] is None
  120. assert a_state._backend_vars["_be_int"] == 0
  121. assert a_state._backend_vars["_be_str"] == "42"
  122. assert count_a.text == "0"
  123. button_a.click()
  124. assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1"
  125. button_a.click()
  126. assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2"
  127. button_inc_a.click()
  128. assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
  129. root_state = await component_state_app.get_state(root_state_token)
  130. a_state = root_state.substates[a_state_name]
  131. b_state = root_state.substates[b_state_name]
  132. assert a_state._backend_vars != a_state.backend_vars
  133. assert a_state._be == a_state._backend_vars["_be"] == 3
  134. assert b_state._be is None
  135. assert b_state._backend_vars["_be"] is None
  136. assert count_b.text == "0"
  137. button_b.click()
  138. assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"
  139. button_b.click()
  140. assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"
  141. root_state = await component_state_app.get_state(root_state_token)
  142. a_state = root_state.substates[a_state_name]
  143. b_state = root_state.substates[b_state_name]
  144. assert b_state._backend_vars != b_state.backend_vars
  145. assert b_state._be == b_state._backend_vars["_be"] == 2
  146. # Check locally-defined substate style
  147. count_c = driver.find_element(By.ID, "count-c")
  148. count_d = driver.find_element(By.ID, "count-d")
  149. button_c = driver.find_element(By.ID, "button-c")
  150. button_d = driver.find_element(By.ID, "button-d")
  151. assert component_state_app.poll_for_content(count_c, exp_not_equal="") == "0"
  152. assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
  153. button_c.click()
  154. assert component_state_app.poll_for_content(count_c, exp_not_equal="0") == "1"
  155. assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
  156. button_c.click()
  157. assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
  158. assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
  159. button_d.click()
  160. assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
  161. assert component_state_app.poll_for_content(count_d, exp_not_equal="0") == "1"