test_component_state.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. """Test that per-component state scaffold works and operates independently."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.state import State, _substate_key
  6. from reflex.testing import AppHarness
  7. from . import utils
  8. def ComponentStateApp():
  9. """App using per component state."""
  10. from typing import Generic, TypeVar
  11. import reflex as rx
  12. E = TypeVar("E")
  13. class MultiCounter(rx.ComponentState, Generic[E]):
  14. count: int = 0
  15. _be: E
  16. _be_int: int
  17. _be_str: str = "42"
  18. def increment(self):
  19. self.count += 1
  20. self._be = self.count # type: ignore
  21. @classmethod
  22. def get_component(cls, *children, **props):
  23. return rx.vstack(
  24. *children,
  25. rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"),
  26. rx.button(
  27. "Increment",
  28. on_click=cls.increment,
  29. id=f"button-{props.get('id', 'default')}",
  30. ),
  31. **props,
  32. )
  33. app = rx.App(state=rx.State) # noqa
  34. @rx.page()
  35. def index():
  36. mc_a = MultiCounter.create(id="a")
  37. mc_b = MultiCounter.create(id="b")
  38. assert mc_a.State != mc_b.State
  39. return rx.vstack(
  40. mc_a,
  41. mc_b,
  42. rx.button(
  43. "Inc A",
  44. on_click=mc_a.State.increment, # type: ignore
  45. id="inc-a",
  46. ),
  47. rx.text(
  48. mc_a.State.get_name() if mc_a.State is not None else "",
  49. id="a_state_name",
  50. ),
  51. rx.text(
  52. mc_b.State.get_name() if mc_b.State is not None else "",
  53. id="b_state_name",
  54. ),
  55. )
  56. @pytest.fixture()
  57. def component_state_app(tmp_path) -> Generator[AppHarness, None, None]:
  58. """Start ComponentStateApp app at tmp_path via AppHarness.
  59. Args:
  60. tmp_path: pytest tmp_path fixture
  61. Yields:
  62. running AppHarness instance
  63. """
  64. with AppHarness.create(
  65. root=tmp_path,
  66. app_source=ComponentStateApp, # type: ignore
  67. ) as harness:
  68. yield harness
  69. @pytest.mark.asyncio
  70. async def test_component_state_app(component_state_app: AppHarness):
  71. """Increment counters independently.
  72. Args:
  73. component_state_app: harness for ComponentStateApp app
  74. """
  75. assert component_state_app.app_instance is not None, "app is not running"
  76. driver = component_state_app.frontend()
  77. ss = utils.SessionStorage(driver)
  78. assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
  79. root_state_token = _substate_key(ss.get("token"), State)
  80. count_a = driver.find_element(By.ID, "count-a")
  81. count_b = driver.find_element(By.ID, "count-b")
  82. button_a = driver.find_element(By.ID, "button-a")
  83. button_b = driver.find_element(By.ID, "button-b")
  84. button_inc_a = driver.find_element(By.ID, "inc-a")
  85. # Check that backend vars in mixins are okay
  86. a_state_name = driver.find_element(By.ID, "a_state_name").text
  87. b_state_name = driver.find_element(By.ID, "b_state_name").text
  88. root_state = await component_state_app.get_state(root_state_token)
  89. a_state = root_state.substates[a_state_name]
  90. b_state = root_state.substates[b_state_name]
  91. assert a_state._backend_vars == a_state.backend_vars
  92. assert a_state._backend_vars == b_state._backend_vars
  93. assert a_state._backend_vars["_be"] is None
  94. assert a_state._backend_vars["_be_int"] == 0
  95. assert a_state._backend_vars["_be_str"] == "42"
  96. assert count_a.text == "0"
  97. button_a.click()
  98. assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1"
  99. button_a.click()
  100. assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2"
  101. button_inc_a.click()
  102. assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
  103. root_state = await component_state_app.get_state(root_state_token)
  104. a_state = root_state.substates[a_state_name]
  105. b_state = root_state.substates[b_state_name]
  106. assert a_state._backend_vars != a_state.backend_vars
  107. assert a_state._be == a_state._backend_vars["_be"] == 3
  108. assert b_state._be is None
  109. assert b_state._backend_vars["_be"] is None
  110. assert count_b.text == "0"
  111. button_b.click()
  112. assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"
  113. button_b.click()
  114. assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"
  115. root_state = await component_state_app.get_state(root_state_token)
  116. a_state = root_state.substates[a_state_name]
  117. b_state = root_state.substates[b_state_name]
  118. assert b_state._backend_vars != b_state.backend_vars
  119. assert b_state._be == b_state._backend_vars["_be"] == 2