test_background_task.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. """Test @rx.background task functionality."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
    """Test that background tasks work as expected."""
    # This body is handed to AppHarness as ``app_source`` (see the
    # ``background_task`` fixture); presumably it is extracted and run as the
    # app module, which is why its imports live inside the function.
    import asyncio

    import pytest

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        # Rendered in the "counter" heading; the handlers below adjust it.
        counter: int = 0
        # Bumped once each time a looping task starts; never read in this app.
        _task_id: int = 0
        # Bound to the "iterations" input; sets how many steps the looping
        # tasks (handle_event, handle_event_yield_only, racy_task) run.
        iterations: int = 10

        @rx.background
        async def handle_event(self):
            # Each step mutates state inside its own ``async with self`` block
            # and sleeps outside it, presumably so other events can interleave.
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.background
        async def handle_event_yield_only(self):
            # Never touches ``counter`` directly: it yields one event per step,
            # alternating increment_arbitrary(1) (even ix) and increment()
            # (odd ix), so each run adds ``iterations`` to the counter in total.
            async with self:
                self._task_id += 1
            for ix in range(int(self.iterations)):
                if ix % 2 == 0:
                    yield State.increment_arbitrary(1)  # type: ignore
                else:
                    yield State.increment()  # type: ignore
                await asyncio.sleep(0.005)

        def increment(self):
            # Regular (non-background) handler.
            self.counter += 1

        @rx.background
        async def increment_arbitrary(self, amount: int):
            async with self:
                self.counter += int(amount)

        def reset_counter(self):
            self.counter = 0

        async def blocking_pause(self):
            # Regular async handler that only sleeps 20 ms; presumably it holds
            # up this client's other regular events while it runs.
            await asyncio.sleep(0.02)

        @rx.background
        async def non_blocking_pause(self):
            # Same 20 ms sleep, but as a background task.
            await asyncio.sleep(0.02)

        async def racy_task(self):
            # Not bound to any button; handle_racy_event runs four copies of it
            # concurrently. Same body as handle_event.
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.background
        async def handle_racy_event(self):
            # Four concurrent racy_task coroutines: expected to add
            # 4 * iterations to the counter.
            await asyncio.gather(
                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
            )

        @rx.background
        async def nested_async_with_self(self):
            async with self:
                self.counter += 1
                # Re-entering ``async with self`` while already inside it must
                # raise, so the counter is incremented exactly once.
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            # ThirdState is defined later in this function; the name is only
            # resolved when this coroutine actually runs.
            third_state = await self.get_state(ThirdState)
            await third_state._triple_count()

    class OtherState(rx.State):
        @rx.background
        async def get_other_state(self):
            # Starting from counter == 0 this ends at 12:
            # (0 + 1) * 3 = 3, then (3 + 1) * 3 = 12.
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()
            # Outside the ``async with self`` block the fetched state is
            # immutable: both an indirect and a direct modification must raise.
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1
            # Entering the fetched state itself makes it mutable again.
            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        async def _triple_count(self):
            # Only reached via State.triple_count in this app.
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        # Each child element has an id so the tests can locate it with By.ID.
        return rx.vstack(
            rx.chakra.input(
                id="token", value=State.router.session.client_token, is_read_only=True
            ),
            rx.heading(State.counter, id="counter"),
            rx.chakra.input(
                id="iterations",
                placeholder="Iterations",
                value=State.iterations.to_string(),  # type: ignore
                on_change=State.set_iterations,  # type: ignore
            ),
            rx.button(
                "Delayed Increment",
                on_click=State.handle_event,
                id="delayed-increment",
            ),
            rx.button(
                "Yield Increment",
                on_click=State.handle_event_yield_only,
                id="yield-increment",
            ),
            rx.button("Increment 1", on_click=State.increment, id="increment"),
            rx.button(
                "Blocking Pause",
                on_click=State.blocking_pause,
                id="blocking-pause",
            ),
            rx.button(
                "Non-Blocking Pause",
                on_click=State.non_blocking_pause,
                id="non-blocking-pause",
            ),
            rx.button(
                "Racy Increment (x4)",
                on_click=State.handle_racy_event,
                id="racy-increment",
            ),
            rx.button(
                "Nested Async with Self",
                on_click=State.nested_async_with_self,
                id="nested-async-with-self",
            ),
            rx.button(
                "Increment from OtherState",
                on_click=OtherState.get_other_state,
                id="increment-from-other-state",
            ),
            rx.button("Reset", on_click=State.reset_counter, id="reset"),
        )

    app = rx.App(state=rx.State)
    app.add_page(index)
  139. @pytest.fixture(scope="module")
  140. def background_task(
  141. tmp_path_factory,
  142. ) -> Generator[AppHarness, None, None]:
  143. """Start BackgroundTask app at tmp_path via AppHarness.
  144. Args:
  145. tmp_path_factory: pytest tmp_path_factory fixture
  146. Yields:
  147. running AppHarness instance
  148. """
  149. with AppHarness.create(
  150. root=tmp_path_factory.mktemp(f"background_task"),
  151. app_source=BackgroundTask, # type: ignore
  152. ) as harness:
  153. yield harness
  154. @pytest.fixture
  155. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  156. """Get an instance of the browser open to the background_task app.
  157. Args:
  158. background_task: harness for BackgroundTask app
  159. Yields:
  160. WebDriver instance.
  161. """
  162. assert background_task.app_instance is not None, "app is not running"
  163. driver = background_task.frontend()
  164. try:
  165. yield driver
  166. finally:
  167. driver.quit()
  168. @pytest.fixture()
  169. def token(background_task: AppHarness, driver: WebDriver) -> str:
  170. """Get a function that returns the active token.
  171. Args:
  172. background_task: harness for BackgroundTask app.
  173. driver: WebDriver instance.
  174. Returns:
  175. The token for the connected client
  176. """
  177. assert background_task.app_instance is not None
  178. token_input = driver.find_element(By.ID, "token")
  179. assert token_input
  180. # wait for the backend connection to send the token
  181. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  182. assert token is not None
  183. return token
  184. def test_background_task(
  185. background_task: AppHarness,
  186. driver: WebDriver,
  187. token: str,
  188. ):
  189. """Test that background tasks work as expected.
  190. Args:
  191. background_task: harness for BackgroundTask app.
  192. driver: WebDriver instance.
  193. token: The token for the connected client.
  194. """
  195. assert background_task.app_instance is not None
  196. # get a reference to all buttons
  197. delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
  198. yield_increment_button = driver.find_element(By.ID, "yield-increment")
  199. increment_button = driver.find_element(By.ID, "increment")
  200. blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
  201. non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
  202. racy_increment_button = driver.find_element(By.ID, "racy-increment")
  203. driver.find_element(By.ID, "reset")
  204. # get a reference to the counter
  205. counter = driver.find_element(By.ID, "counter")
  206. # get a reference to the iterations input
  207. iterations_input = driver.find_element(By.ID, "iterations")
  208. # kick off background tasks
  209. iterations_input.clear()
  210. iterations_input.send_keys("50")
  211. delayed_increment_button.click()
  212. blocking_pause_button.click()
  213. delayed_increment_button.click()
  214. for _ in range(10):
  215. increment_button.click()
  216. blocking_pause_button.click()
  217. delayed_increment_button.click()
  218. delayed_increment_button.click()
  219. yield_increment_button.click()
  220. racy_increment_button.click()
  221. non_blocking_pause_button.click()
  222. yield_increment_button.click()
  223. blocking_pause_button.click()
  224. yield_increment_button.click()
  225. for _ in range(10):
  226. increment_button.click()
  227. yield_increment_button.click()
  228. blocking_pause_button.click()
  229. assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
  230. # all tasks should have exited and cleaned up
  231. assert background_task._poll_for(
  232. lambda: not background_task.app_instance.background_tasks # type: ignore
  233. )
  234. def test_nested_async_with_self(
  235. background_task: AppHarness,
  236. driver: WebDriver,
  237. token: str,
  238. ):
  239. """Test that nested async with self in the same coroutine raises Exception.
  240. Args:
  241. background_task: harness for BackgroundTask app.
  242. driver: WebDriver instance.
  243. token: The token for the connected client.
  244. """
  245. assert background_task.app_instance is not None
  246. # get a reference to all buttons
  247. nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
  248. increment_button = driver.find_element(By.ID, "increment")
  249. # get a reference to the counter
  250. counter = driver.find_element(By.ID, "counter")
  251. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  252. nested_async_with_self_button.click()
  253. assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
  254. increment_button.click()
  255. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  256. def test_get_state(
  257. background_task: AppHarness,
  258. driver: WebDriver,
  259. token: str,
  260. ):
  261. """Test that get_state returns a state bound to the correct StateProxy.
  262. Args:
  263. background_task: harness for BackgroundTask app.
  264. driver: WebDriver instance.
  265. token: The token for the connected client.
  266. """
  267. assert background_task.app_instance is not None
  268. # get a reference to all buttons
  269. other_state_button = driver.find_element(By.ID, "increment-from-other-state")
  270. increment_button = driver.find_element(By.ID, "increment")
  271. # get a reference to the counter
  272. counter = driver.find_element(By.ID, "counter")
  273. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  274. other_state_button.click()
  275. assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
  276. increment_button.click()
  277. assert background_task._poll_for(lambda: counter.text == "13", timeout=5)