# test_background_task.py
  1. """Test @rx.event(background=True) task functionality."""
  2. from collections.abc import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from selenium.webdriver.remote.webdriver import WebDriver
  6. from reflex.testing import DEFAULT_TIMEOUT, AppHarness
def BackgroundTask():
    """Test that background tasks work as expected.

    App source for the ``background_task`` fixture (passed to
    ``AppHarness.create`` as ``app_source``). It defines three states
    (``State``, ``OtherState``, ``ThirdState``), an ``index`` page with one
    control per event handler, and the ``app`` that serves that page.

    NOTE(review): every import lives inside this function, presumably because
    AppHarness turns this function's source into a standalone app module;
    confirm against ``reflex.testing.AppHarness``.
    """
    import asyncio

    import pytest

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        # Shared counter mutated by every increment handler; rendered as #counter.
        counter: int = 0
        # Bumped once at the start of each delayed / yield-only / racy task run
        # (not read anywhere in this app).
        _task_id: int = 0
        # Number of loop steps for the delayed, yield-only and racy handlers.
        iterations: rx.Field[int] = rx.field(10)

        @rx.event
        def set_iterations(self, value: str):
            # Bound to the #iterations input's on_change.
            # NOTE(review): int("") raises ValueError, so an emptied input box
            # makes this handler fail; confirm that is acceptable for the test.
            self.iterations = int(value)

        @rx.var
        async def counter_async_cv(self) -> int:
            """This exists solely as an integration test for background tasks triggering async var updates.

            Returns:
                The current value of the counter.
            """
            return self.counter

        @rx.event(background=True)
        async def handle_event(self):
            # Background handler: adds `iterations` to counter one step at a
            # time, entering `async with self` separately for each step and
            # sleeping 5 ms in between (presumably so other events interleave).
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_event_yield_only(self):
            # Never mutates counter itself: every step yields an event instead.
            # Even steps use the background increment_arbitrary(1), odd steps
            # the regular increment(), so the total added is `iterations`.
            async with self:
                self._task_id += 1
            for ix in range(int(self.iterations)):
                if ix % 2 == 0:
                    yield State.increment_arbitrary(1)
                else:
                    yield State.increment()
                await asyncio.sleep(0.005)

        @rx.event
        def increment(self):
            # Regular (non-background) +1.
            self.counter += 1

        @rx.event(background=True)
        async def increment_arbitrary(self, amount: int):
            # Background +amount, applied under `async with self`.
            async with self:
                self.counter += int(amount)

        @rx.event
        def reset_counter(self):
            self.counter = 0

        @rx.event
        async def blocking_pause(self):
            # Regular handler that sleeps 20 ms; presumably holds up this
            # client's other regular events while it waits.
            await asyncio.sleep(0.02)

        @rx.event(background=True)
        async def non_blocking_pause(self):
            # Same 20 ms sleep, but as a background task for contrast.
            await asyncio.sleep(0.02)

        async def racy_task(self):
            # Plain helper (no event decorator) with the same loop as
            # handle_event; run several times concurrently by handle_racy_event.
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_racy_event(self):
            # Four concurrent racy_task copies: counter should grow by
            # 4 * iterations if no update is lost.
            await asyncio.gather(
                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
            )

        @rx.event(background=True)
        async def nested_async_with_self(self):
            # Re-entering `async with self` while already inside it must raise
            # ImmutableStateError; only the outer +1 should be applied.
            async with self:
                self.counter += 1
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            # Helper (no decorator) that triples counter via ThirdState.
            # ThirdState is defined further down; the name is resolved only
            # when this coroutine runs.
            third_state = await self.get_state(ThirdState)
            await third_state._triple_count()

        @rx.event(background=True)
        async def yield_in_async_with_self(self):
            # Yields from inside `async with self`; the second +1 after the
            # yield must still be allowed.
            async with self:
                self.counter += 1
                yield
                self.counter += 1

    class OtherState(rx.State):
        @rx.event(background=True)
        async def get_other_state(self):
            # Starting from 0: +1 -> 1, triple -> 3.
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()
            # Outside `async with self` the fetched state must be read-only:
            # both the mutating helper and a direct mutation must raise.
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1
            # Locking the fetched state directly makes it writable again:
            # +1 -> 4, triple -> 12.
            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        async def _triple_count(self):
            # Multiplies State.counter by 3; reached through State.triple_count.
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        # One control per handler above; the ids are what the tests look up.
        return rx.vstack(
            # Read-only field showing this client's token (polled by the
            # `token` fixture).
            rx.input(
                id="token", value=State.router.session.client_token, is_read_only=True
            ),
            rx.hstack(
                rx.heading(State.counter, id="counter"),
                rx.text(State.counter_async_cv, size="1", id="counter-async-cv"),
            ),
            rx.input(
                id="iterations",
                placeholder="Iterations",
                value=State.iterations.to_string(),
                on_change=State.set_iterations,
            ),
            rx.button(
                "Delayed Increment",
                on_click=State.handle_event,
                id="delayed-increment",
            ),
            rx.button(
                "Yield Increment",
                on_click=State.handle_event_yield_only,
                id="yield-increment",
            ),
            rx.button("Increment 1", on_click=State.increment, id="increment"),
            rx.button(
                "Blocking Pause",
                on_click=State.blocking_pause,
                id="blocking-pause",
            ),
            rx.button(
                "Non-Blocking Pause",
                on_click=State.non_blocking_pause,
                id="non-blocking-pause",
            ),
            rx.button(
                "Racy Increment (x4)",
                on_click=State.handle_racy_event,
                id="racy-increment",
            ),
            rx.button(
                "Nested Async with Self",
                on_click=State.nested_async_with_self,
                id="nested-async-with-self",
            ),
            rx.button(
                "Increment from OtherState",
                on_click=OtherState.get_other_state,
                id="increment-from-other-state",
            ),
            rx.button(
                "Yield in Async with Self",
                on_click=State.yield_in_async_with_self,
                id="yield-in-async-with-self",
            ),
            rx.button("Reset", on_click=State.reset_counter, id="reset"),
        )

    app = rx.App()
    app.add_page(index)
  167. @pytest.fixture(scope="module")
  168. def background_task(
  169. tmp_path_factory,
  170. ) -> Generator[AppHarness, None, None]:
  171. """Start BackgroundTask app at tmp_path via AppHarness.
  172. Args:
  173. tmp_path_factory: pytest tmp_path_factory fixture
  174. Yields:
  175. running AppHarness instance
  176. """
  177. with AppHarness.create(
  178. root=tmp_path_factory.mktemp("background_task"),
  179. app_source=BackgroundTask,
  180. ) as harness:
  181. yield harness
  182. @pytest.fixture
  183. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  184. """Get an instance of the browser open to the background_task app.
  185. Args:
  186. background_task: harness for BackgroundTask app
  187. Yields:
  188. WebDriver instance.
  189. """
  190. assert background_task.app_instance is not None, "app is not running"
  191. driver = background_task.frontend()
  192. try:
  193. yield driver
  194. finally:
  195. driver.quit()
  196. @pytest.fixture()
  197. def token(background_task: AppHarness, driver: WebDriver) -> str:
  198. """Get a function that returns the active token.
  199. Args:
  200. background_task: harness for BackgroundTask app.
  201. driver: WebDriver instance.
  202. Returns:
  203. The token for the connected client
  204. """
  205. assert background_task.app_instance is not None
  206. token_input = driver.find_element(By.ID, "token")
  207. assert token_input
  208. # wait for the backend connection to send the token
  209. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  210. assert token is not None
  211. return token
def test_background_task(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Test that background tasks work as expected.

    Fires a fixed interleaving of delayed (background), yield-only, racy,
    blocking and non-blocking handlers. It reloads the page twice while tasks
    are still running, then checks that the counter settles at an exact total.
    With iterations set to 50 the expected 620 is made up of:

    - 4 delayed increments x 50 = 200
    - 4 yield increments x 50 = 200 (3 before the reloads, 1 after)
    - 1 racy increment = 4 concurrent tasks x 50 = 200
    - 20 plain increments (10 before the reloads, 10 after)

    The pause buttons only sleep and never change the counter.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    # get a reference to all buttons
    delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
    yield_increment_button = driver.find_element(By.ID, "yield-increment")
    increment_button = driver.find_element(By.ID, "increment")
    blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
    non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
    racy_increment_button = driver.find_element(By.ID, "racy-increment")
    # Only checks that the reset button is rendered; the element is not used.
    driver.find_element(By.ID, "reset")

    # get a reference to the iterations input
    iterations_input = driver.find_element(By.ID, "iterations")

    # kick off background tasks (each looping handler runs 50 iterations)
    iterations_input.clear()
    iterations_input.send_keys("50")
    delayed_increment_button.click()  # +50
    blocking_pause_button.click()
    delayed_increment_button.click()  # +50
    for _ in range(10):
        increment_button.click()  # +1 each
    blocking_pause_button.click()
    delayed_increment_button.click()  # +50
    delayed_increment_button.click()  # +50
    yield_increment_button.click()  # +50
    racy_increment_button.click()  # +200 (4 concurrent tasks x 50)
    non_blocking_pause_button.click()
    yield_increment_button.click()  # +50
    blocking_pause_button.click()
    yield_increment_button.click()  # +50

    # Refresh a few times while the task is running.
    driver.refresh()
    driver.refresh()

    # Get new references after page reloads.
    increment_button = driver.find_element(By.ID, "increment")
    yield_increment_button = driver.find_element(By.ID, "yield-increment")
    blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
    counter = driver.find_element(By.ID, "counter")
    counter_async_cv = driver.find_element(By.ID, "counter-async-cv")

    # Click more buttons.
    for _ in range(10):
        increment_button.click()  # +1 each
    # +50; assumes iterations is still 50 after the reloads (the 620 total relies on it).
    yield_increment_button.click()
    blocking_pause_button.click()

    # Make sure the final total is what we expect.
    assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
    assert background_task._poll_for(lambda: counter_async_cv.text == "620", timeout=40)
    # all tasks should have exited and cleaned up
    assert background_task._poll_for(
        lambda: not background_task.app_instance._background_tasks  # pyright: ignore [reportOptionalMemberAccess]
    )
  272. def test_nested_async_with_self(
  273. background_task: AppHarness,
  274. driver: WebDriver,
  275. token: str,
  276. ):
  277. """Test that nested async with self in the same coroutine raises Exception.
  278. Args:
  279. background_task: harness for BackgroundTask app.
  280. driver: WebDriver instance.
  281. token: The token for the connected client.
  282. """
  283. assert background_task.app_instance is not None
  284. # get a reference to all buttons
  285. nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
  286. increment_button = driver.find_element(By.ID, "increment")
  287. # get a reference to the counter
  288. counter = driver.find_element(By.ID, "counter")
  289. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  290. nested_async_with_self_button.click()
  291. assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
  292. increment_button.click()
  293. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  294. def test_get_state(
  295. background_task: AppHarness,
  296. driver: WebDriver,
  297. token: str,
  298. ):
  299. """Test that get_state returns a state bound to the correct StateProxy.
  300. Args:
  301. background_task: harness for BackgroundTask app.
  302. driver: WebDriver instance.
  303. token: The token for the connected client.
  304. """
  305. assert background_task.app_instance is not None
  306. # get a reference to all buttons
  307. other_state_button = driver.find_element(By.ID, "increment-from-other-state")
  308. increment_button = driver.find_element(By.ID, "increment")
  309. # get a reference to the counter
  310. counter = driver.find_element(By.ID, "counter")
  311. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  312. other_state_button.click()
  313. assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
  314. increment_button.click()
  315. assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
  316. def test_yield_in_async_with_self(
  317. background_task: AppHarness,
  318. driver: WebDriver,
  319. token: str,
  320. ):
  321. """Test that yielding inside async with self does not disable mutability.
  322. Args:
  323. background_task: harness for BackgroundTask app.
  324. driver: WebDriver instance.
  325. token: The token for the connected client.
  326. """
  327. assert background_task.app_instance is not None
  328. # get a reference to all buttons
  329. yield_in_async_with_self_button = driver.find_element(
  330. By.ID, "yield-in-async-with-self"
  331. )
  332. # get a reference to the counter
  333. counter = driver.find_element(By.ID, "counter")
  334. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  335. yield_in_async_with_self_button.click()
  336. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  337. def test_background_task_refresh(
  338. background_task: AppHarness,
  339. driver: WebDriver,
  340. token: str,
  341. ):
  342. """Test that background tasks keep working when the page is refreshed.
  343. Args:
  344. background_task: harness for BackgroundTask app.
  345. driver: WebDriver instance.
  346. token: The token for the connected client.
  347. """
  348. assert background_task.app_instance is not None
  349. racy_increment_button = driver.find_element(By.ID, "racy-increment")
  350. driver.find_element(By.ID, "reset")
  351. # get a reference to the iterations input
  352. iterations_input = driver.find_element(By.ID, "iterations")
  353. # kick off background tasks
  354. iterations_input.clear()
  355. iterations_input.send_keys("50")
  356. racy_increment_button.click()
  357. # Refresh a few times while the task is running.
  358. driver.refresh()
  359. driver.refresh()
  360. # Get new references after page reloads.
  361. counter = driver.find_element(By.ID, "counter")
  362. counter_async_cv = driver.find_element(By.ID, "counter-async-cv")
  363. # Make sure the final total is what we expect.
  364. assert background_task._poll_for(lambda: counter.text == "200", timeout=40)
  365. assert background_task._poll_for(lambda: counter_async_cv.text == "200", timeout=40)
  366. # all tasks should have exited and cleaned up
  367. assert background_task._poll_for(
  368. lambda: not background_task.app_instance._background_tasks # pyright: ignore [reportOptionalMemberAccess]
  369. )