test_background_task.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. """Test @rx.background task functionality."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
    """Reflex app source exercised by the background task tests.

    Passed to ``AppHarness.create`` as ``app_source`` by the ``background_task``
    fixture. All imports happen inside the function body; presumably the
    harness turns this body into a standalone app module
    (NOTE(review): confirm against ``reflex.testing.AppHarness``).
    """
    import asyncio

    import pytest

    import reflex_chakra as rc

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        """Main state whose counter the tests observe through the UI."""

        # Rendered in the "counter" heading; every increment path targets it.
        counter: int = 0
        # Bumped once at the start of each delayed / yield-only / racy task;
        # never rendered by index().
        _task_id: int = 0
        # Loop count for the looping tasks; bound to the "iterations" input.
        iterations: int = 10

        @rx.background
        async def handle_event(self):
            """Add 1 to the counter ``iterations`` times, sleeping between steps."""
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                # Each increment gets its own ``async with self`` block and the
                # sleep happens outside it (presumably so other events are not
                # held up while this task sleeps).
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.background
        async def handle_event_yield_only(self):
            """Increment the counter by yielding events instead of mutating it.

            Alternates between the background ``increment_arbitrary(1)`` (even
            steps) and the regular ``increment`` handler (odd steps), so the
            counter grows by ``iterations`` in total.
            """
            async with self:
                self._task_id += 1
            for ix in range(int(self.iterations)):
                if ix % 2 == 0:
                    yield State.increment_arbitrary(1)  # type: ignore
                else:
                    yield State.increment()  # type: ignore
                await asyncio.sleep(0.005)

        def increment(self):
            """Regular (non-background) handler: add 1 to the counter."""
            self.counter += 1

        @rx.background
        async def increment_arbitrary(self, amount: int):
            """Background handler: add ``amount`` to the counter."""
            async with self:
                self.counter += int(amount)

        def reset_counter(self):
            """Set the counter back to 0."""
            self.counter = 0

        async def blocking_pause(self):
            """Regular async handler that only sleeps; no state is touched."""
            await asyncio.sleep(0.02)

        @rx.background
        async def non_blocking_pause(self):
            """Background counterpart of ``blocking_pause``: only sleeps."""
            await asyncio.sleep(0.02)

        async def racy_task(self):
            """Same loop as ``handle_event``, meant to be run concurrently.

            Not decorated with ``@rx.background`` and not wired to any
            component; it is only awaited from ``handle_racy_event``.
            """
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.background
        async def handle_racy_event(self):
            """Run four ``racy_task`` loops concurrently on the same state.

            The counter is expected to grow by ``4 * iterations``.
            """
            await asyncio.gather(
                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
            )

        @rx.background
        async def nested_async_with_self(self):
            """Check that nesting ``async with self`` is rejected.

            The nested ``async with self`` must raise ``ImmutableStateError``;
            the tests expect only the outer increment to take effect.
            """
            async with self:
                self.counter += 1
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            """Triple the counter by delegating to ``ThirdState._triple_count``."""
            third_state = await self.get_state(ThirdState)
            await third_state._triple_count()

        @rx.background
        async def yield_in_async_with_self(self):
            """Increment twice with a bare ``yield`` in between.

            The second increment, made after the ``yield`` but still inside
            ``async with self``, checks that the state stays mutable there.
            """
            async with self:
                self.counter += 1
                yield
                self.counter += 1

    class OtherState(rx.State):
        """Second state that modifies ``State`` through ``get_state``."""

        @rx.background
        async def get_other_state(self):
            """Mutate ``State`` from another state's background task.

            Expected counter trajectory starting from 0: ``(0 + 1) * 3 = 3``
            inside the first block, then ``(3 + 1) * 3 = 12`` inside
            ``async with state``.
            """
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()
            # Outside any ``async with`` the fetched state must reject
            # mutation, both through a helper and directly.
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1
            # Entering the fetched state's own context makes it mutable again.
            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        """Helper state whose private method triples ``State.counter``."""

        async def _triple_count(self):
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        """Build the single test page; the element ids are what the tests query."""
        return rx.vstack(
            # Read-only input showing the client token; the ``token`` fixture
            # polls it to know the backend connection is up.
            rc.input(
                id="token", value=State.router.session.client_token, is_read_only=True
            ),
            rx.heading(State.counter, id="counter"),
            rc.input(
                id="iterations",
                placeholder="Iterations",
                value=State.iterations.to_string(),  # type: ignore
                on_change=State.set_iterations,  # type: ignore
            ),
            rx.button(
                "Delayed Increment",
                on_click=State.handle_event,
                id="delayed-increment",
            ),
            rx.button(
                "Yield Increment",
                on_click=State.handle_event_yield_only,
                id="yield-increment",
            ),
            rx.button("Increment 1", on_click=State.increment, id="increment"),
            rx.button(
                "Blocking Pause",
                on_click=State.blocking_pause,
                id="blocking-pause",
            ),
            rx.button(
                "Non-Blocking Pause",
                on_click=State.non_blocking_pause,
                id="non-blocking-pause",
            ),
            rx.button(
                "Racy Increment (x4)",
                on_click=State.handle_racy_event,
                id="racy-increment",
            ),
            rx.button(
                "Nested Async with Self",
                on_click=State.nested_async_with_self,
                id="nested-async-with-self",
            ),
            rx.button(
                "Increment from OtherState",
                on_click=OtherState.get_other_state,
                id="increment-from-other-state",
            ),
            rx.button(
                "Yield in Async with Self",
                on_click=State.yield_in_async_with_self,
                id="yield-in-async-with-self",
            ),
            rx.button("Reset", on_click=State.reset_counter, id="reset"),
        )

    app = rx.App(state=rx.State)
    app.add_page(index)
  151. @pytest.fixture(scope="module")
  152. def background_task(
  153. tmp_path_factory,
  154. ) -> Generator[AppHarness, None, None]:
  155. """Start BackgroundTask app at tmp_path via AppHarness.
  156. Args:
  157. tmp_path_factory: pytest tmp_path_factory fixture
  158. Yields:
  159. running AppHarness instance
  160. """
  161. with AppHarness.create(
  162. root=tmp_path_factory.mktemp(f"background_task"),
  163. app_source=BackgroundTask, # type: ignore
  164. ) as harness:
  165. yield harness
  166. @pytest.fixture
  167. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  168. """Get an instance of the browser open to the background_task app.
  169. Args:
  170. background_task: harness for BackgroundTask app
  171. Yields:
  172. WebDriver instance.
  173. """
  174. assert background_task.app_instance is not None, "app is not running"
  175. driver = background_task.frontend()
  176. try:
  177. yield driver
  178. finally:
  179. driver.quit()
  180. @pytest.fixture()
  181. def token(background_task: AppHarness, driver: WebDriver) -> str:
  182. """Get a function that returns the active token.
  183. Args:
  184. background_task: harness for BackgroundTask app.
  185. driver: WebDriver instance.
  186. Returns:
  187. The token for the connected client
  188. """
  189. assert background_task.app_instance is not None
  190. token_input = driver.find_element(By.ID, "token")
  191. assert token_input
  192. # wait for the backend connection to send the token
  193. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  194. assert token is not None
  195. return token
  196. def test_background_task(
  197. background_task: AppHarness,
  198. driver: WebDriver,
  199. token: str,
  200. ):
  201. """Test that background tasks work as expected.
  202. Args:
  203. background_task: harness for BackgroundTask app.
  204. driver: WebDriver instance.
  205. token: The token for the connected client.
  206. """
  207. assert background_task.app_instance is not None
  208. # get a reference to all buttons
  209. delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
  210. yield_increment_button = driver.find_element(By.ID, "yield-increment")
  211. increment_button = driver.find_element(By.ID, "increment")
  212. blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
  213. non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
  214. racy_increment_button = driver.find_element(By.ID, "racy-increment")
  215. driver.find_element(By.ID, "reset")
  216. # get a reference to the counter
  217. counter = driver.find_element(By.ID, "counter")
  218. # get a reference to the iterations input
  219. iterations_input = driver.find_element(By.ID, "iterations")
  220. # kick off background tasks
  221. iterations_input.clear()
  222. iterations_input.send_keys("50")
  223. delayed_increment_button.click()
  224. blocking_pause_button.click()
  225. delayed_increment_button.click()
  226. for _ in range(10):
  227. increment_button.click()
  228. blocking_pause_button.click()
  229. delayed_increment_button.click()
  230. delayed_increment_button.click()
  231. yield_increment_button.click()
  232. racy_increment_button.click()
  233. non_blocking_pause_button.click()
  234. yield_increment_button.click()
  235. blocking_pause_button.click()
  236. yield_increment_button.click()
  237. for _ in range(10):
  238. increment_button.click()
  239. yield_increment_button.click()
  240. blocking_pause_button.click()
  241. assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
  242. # all tasks should have exited and cleaned up
  243. assert background_task._poll_for(
  244. lambda: not background_task.app_instance.background_tasks # type: ignore
  245. )
  246. def test_nested_async_with_self(
  247. background_task: AppHarness,
  248. driver: WebDriver,
  249. token: str,
  250. ):
  251. """Test that nested async with self in the same coroutine raises Exception.
  252. Args:
  253. background_task: harness for BackgroundTask app.
  254. driver: WebDriver instance.
  255. token: The token for the connected client.
  256. """
  257. assert background_task.app_instance is not None
  258. # get a reference to all buttons
  259. nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
  260. increment_button = driver.find_element(By.ID, "increment")
  261. # get a reference to the counter
  262. counter = driver.find_element(By.ID, "counter")
  263. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  264. nested_async_with_self_button.click()
  265. assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
  266. increment_button.click()
  267. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  268. def test_get_state(
  269. background_task: AppHarness,
  270. driver: WebDriver,
  271. token: str,
  272. ):
  273. """Test that get_state returns a state bound to the correct StateProxy.
  274. Args:
  275. background_task: harness for BackgroundTask app.
  276. driver: WebDriver instance.
  277. token: The token for the connected client.
  278. """
  279. assert background_task.app_instance is not None
  280. # get a reference to all buttons
  281. other_state_button = driver.find_element(By.ID, "increment-from-other-state")
  282. increment_button = driver.find_element(By.ID, "increment")
  283. # get a reference to the counter
  284. counter = driver.find_element(By.ID, "counter")
  285. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  286. other_state_button.click()
  287. assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
  288. increment_button.click()
  289. assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
  290. def test_yield_in_async_with_self(
  291. background_task: AppHarness,
  292. driver: WebDriver,
  293. token: str,
  294. ):
  295. """Test that yielding inside async with self does not disable mutability.
  296. Args:
  297. background_task: harness for BackgroundTask app.
  298. driver: WebDriver instance.
  299. token: The token for the connected client.
  300. """
  301. assert background_task.app_instance is not None
  302. # get a reference to all buttons
  303. yield_in_async_with_self_button = driver.find_element(
  304. By.ID, "yield-in-async-with-self"
  305. )
  306. # get a reference to the counter
  307. counter = driver.find_element(By.ID, "counter")
  308. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  309. yield_in_async_with_self_button.click()
  310. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)