1
0

test_background_task.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. """Test @rx.background task functionality."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
  6. def BackgroundTask():
  7. """Test that background tasks work as expected."""
  8. import asyncio
  9. import reflex as rx
  10. class State(rx.State):
  11. counter: int = 0
  12. _task_id: int = 0
  13. iterations: int = 10
  14. @rx.background
  15. async def handle_event(self):
  16. async with self:
  17. self._task_id += 1
  18. for _ix in range(int(self.iterations)):
  19. async with self:
  20. self.counter += 1
  21. await asyncio.sleep(0.005)
  22. @rx.background
  23. async def handle_event_yield_only(self):
  24. async with self:
  25. self._task_id += 1
  26. for ix in range(int(self.iterations)):
  27. if ix % 2 == 0:
  28. yield State.increment_arbitrary(1) # type: ignore
  29. else:
  30. yield State.increment() # type: ignore
  31. await asyncio.sleep(0.005)
  32. def increment(self):
  33. self.counter += 1
  34. @rx.background
  35. async def increment_arbitrary(self, amount: int):
  36. async with self:
  37. self.counter += int(amount)
  38. def reset_counter(self):
  39. self.counter = 0
  40. async def blocking_pause(self):
  41. await asyncio.sleep(0.02)
  42. @rx.background
  43. async def non_blocking_pause(self):
  44. await asyncio.sleep(0.02)
  45. async def racy_task(self):
  46. async with self:
  47. self._task_id += 1
  48. for _ix in range(int(self.iterations)):
  49. async with self:
  50. self.counter += 1
  51. await asyncio.sleep(0.005)
  52. @rx.background
  53. async def handle_racy_event(self):
  54. await asyncio.gather(
  55. self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
  56. )
  57. def index() -> rx.Component:
  58. return rx.vstack(
  59. rx.chakra.input(
  60. id="token", value=State.router.session.client_token, is_read_only=True
  61. ),
  62. rx.heading(State.counter, id="counter"),
  63. rx.chakra.input(
  64. id="iterations",
  65. placeholder="Iterations",
  66. value=State.iterations.to_string(), # type: ignore
  67. on_change=State.set_iterations, # type: ignore
  68. ),
  69. rx.button(
  70. "Delayed Increment",
  71. on_click=State.handle_event,
  72. id="delayed-increment",
  73. ),
  74. rx.button(
  75. "Yield Increment",
  76. on_click=State.handle_event_yield_only,
  77. id="yield-increment",
  78. ),
  79. rx.button("Increment 1", on_click=State.increment, id="increment"),
  80. rx.button(
  81. "Blocking Pause",
  82. on_click=State.blocking_pause,
  83. id="blocking-pause",
  84. ),
  85. rx.button(
  86. "Non-Blocking Pause",
  87. on_click=State.non_blocking_pause,
  88. id="non-blocking-pause",
  89. ),
  90. rx.button(
  91. "Racy Increment (x4)",
  92. on_click=State.handle_racy_event,
  93. id="racy-increment",
  94. ),
  95. rx.button("Reset", on_click=State.reset_counter, id="reset"),
  96. )
  97. app = rx.App(state=rx.State)
  98. app.add_page(index)
  99. @pytest.fixture(scope="module")
  100. def background_task(
  101. tmp_path_factory,
  102. ) -> Generator[AppHarness, None, None]:
  103. """Start BackgroundTask app at tmp_path via AppHarness.
  104. Args:
  105. tmp_path_factory: pytest tmp_path_factory fixture
  106. Yields:
  107. running AppHarness instance
  108. """
  109. with AppHarness.create(
  110. root=tmp_path_factory.mktemp(f"background_task"),
  111. app_source=BackgroundTask, # type: ignore
  112. ) as harness:
  113. yield harness
  114. @pytest.fixture
  115. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  116. """Get an instance of the browser open to the background_task app.
  117. Args:
  118. background_task: harness for BackgroundTask app
  119. Yields:
  120. WebDriver instance.
  121. """
  122. assert background_task.app_instance is not None, "app is not running"
  123. driver = background_task.frontend()
  124. try:
  125. yield driver
  126. finally:
  127. driver.quit()
  128. @pytest.fixture()
  129. def token(background_task: AppHarness, driver: WebDriver) -> str:
  130. """Get a function that returns the active token.
  131. Args:
  132. background_task: harness for BackgroundTask app.
  133. driver: WebDriver instance.
  134. Returns:
  135. The token for the connected client
  136. """
  137. assert background_task.app_instance is not None
  138. token_input = driver.find_element(By.ID, "token")
  139. assert token_input
  140. # wait for the backend connection to send the token
  141. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  142. assert token is not None
  143. return token
  144. def test_background_task(
  145. background_task: AppHarness,
  146. driver: WebDriver,
  147. token: str,
  148. ):
  149. """Test that background tasks work as expected.
  150. Args:
  151. background_task: harness for BackgroundTask app.
  152. driver: WebDriver instance.
  153. token: The token for the connected client.
  154. """
  155. assert background_task.app_instance is not None
  156. # get a reference to all buttons
  157. delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
  158. yield_increment_button = driver.find_element(By.ID, "yield-increment")
  159. increment_button = driver.find_element(By.ID, "increment")
  160. blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
  161. non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
  162. racy_increment_button = driver.find_element(By.ID, "racy-increment")
  163. driver.find_element(By.ID, "reset")
  164. # get a reference to the counter
  165. counter = driver.find_element(By.ID, "counter")
  166. # get a reference to the iterations input
  167. iterations_input = driver.find_element(By.ID, "iterations")
  168. # kick off background tasks
  169. iterations_input.clear()
  170. iterations_input.send_keys("50")
  171. delayed_increment_button.click()
  172. blocking_pause_button.click()
  173. delayed_increment_button.click()
  174. for _ in range(10):
  175. increment_button.click()
  176. blocking_pause_button.click()
  177. delayed_increment_button.click()
  178. delayed_increment_button.click()
  179. yield_increment_button.click()
  180. racy_increment_button.click()
  181. non_blocking_pause_button.click()
  182. yield_increment_button.click()
  183. blocking_pause_button.click()
  184. yield_increment_button.click()
  185. for _ in range(10):
  186. increment_button.click()
  187. yield_increment_button.click()
  188. blocking_pause_button.click()
  189. assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
  190. # all tasks should have exited and cleaned up
  191. assert background_task._poll_for(
  192. lambda: not background_task.app_instance.background_tasks # type: ignore
  193. )