test_background_task.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. """Test @rx.event(background=True) task functionality."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
    """Test that background tasks work as expected."""
    # NOTE(review): this function is passed as `app_source` to AppHarness.create
    # in the `background_task` fixture; presumably the harness turns this body
    # into the app module, which is why every import is local - confirm.
    import asyncio

    import pytest

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        # Main state under test; `counter` is rendered with id="counter".
        counter: int = 0
        # Bumped at the start of each increment task; not read anywhere in this file.
        _task_id: int = 0
        # Increments performed per task; bound to the id="iterations" input.
        iterations: rx.Field[int] = rx.field(10)

        @rx.event
        def set_iterations(self, value: str):
            """Store the iteration count typed into the iterations input.

            Args:
                value: Raw input text, converted with int().

            NOTE(review): an empty or non-numeric value makes int() raise
            ValueError here.
            """
            self.iterations = int(value)

        @rx.var
        async def counter_async_cv(self) -> int:
            """This exists solely as an integration test for background tasks triggering async var updates.

            Returns:
                The current value of the counter.
            """
            return self.counter

        @rx.event(background=True)
        async def handle_event(self):
            """Increment the counter `iterations` times from a background task.

            Each increment gets its own `async with self` block and the sleep
            sits outside it, presumably so other events can interleave between
            increments.
            """
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_event_yield_only(self):
            """Increment the counter `iterations` times purely by yielding events.

            Even iterations yield the background `increment_arbitrary(1)` event,
            odd iterations yield the regular `increment` event; each adds 1.
            """
            async with self:
                self._task_id += 1
            for ix in range(int(self.iterations)):
                if ix % 2 == 0:
                    yield State.increment_arbitrary(1)
                else:
                    yield State.increment()
                await asyncio.sleep(0.005)

        @rx.event
        def increment(self):
            """Add one to the counter (regular, non-background handler)."""
            self.counter += 1

        @rx.event(background=True)
        async def increment_arbitrary(self, amount: int):
            """Add `amount` (after int() conversion) to the counter in the background."""
            async with self:
                self.counter += int(amount)

        @rx.event
        def reset_counter(self):
            """Set the counter back to zero."""
            self.counter = 0

        @rx.event
        async def blocking_pause(self):
            """Sleep 20 ms in a regular (non-background) handler; touches no state.

            NOTE(review): presumably delays events queued after it for this
            client, unlike `non_blocking_pause` - confirm.
            """
            await asyncio.sleep(0.02)

        @rx.event(background=True)
        async def non_blocking_pause(self):
            """Sleep 20 ms in a background handler; touches no state."""
            await asyncio.sleep(0.02)

        async def racy_task(self):
            """Same loop as `handle_event`; run concurrently by `handle_racy_event`."""
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_racy_event(self):
            """Run four `racy_task` coroutines concurrently on the same state.

            The test expects all 4 * iterations increments to be counted.
            """
            await asyncio.gather(
                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
            )

        @rx.event(background=True)
        async def nested_async_with_self(self):
            """Check that nesting `async with self` inside one coroutine is refused.

            The outer increment applies; the inner `async with self` must raise
            ImmutableStateError, and the test waits for counter == 1, i.e. the
            nested increment is expected not to apply.
            """
            async with self:
                self.counter += 1
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            """Triple the counter by delegating to `ThirdState._triple_count`.

            Not decorated with @rx.event; called directly from
            `OtherState.get_other_state`, where it is expected to raise
            ImmutableStateError unless this state is held via `async with`.
            """
            third_state = await self.get_state(ThirdState)
            await third_state._triple_count()

        @rx.event(background=True)
        async def yield_in_async_with_self(self):
            """Increment twice in one `async with self`, with a bare `yield` between.

            The test expects the counter to reach 2, i.e. the state stays
            mutable after the yield.
            """
            async with self:
                self.counter += 1
                yield
                self.counter += 1

    class OtherState(rx.State):
        # Separate state used to reach `State` through get_state() from a background task.
        @rx.event(background=True)
        async def get_other_state(self):
            """Mutate `State` through a handle obtained with get_state().

            Inside this state's `async with self` the handle is mutable; after
            that block exits, mutating through the handle must raise
            ImmutableStateError until it is entered with `async with state`.
            Starting from 0 the counter ends at ((0 + 1) * 3 + 1) * 3 == 12.
            """
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()
            # Outside any `async with`, both mutation paths must be rejected.
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1
            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        async def _triple_count(self):
            """Multiply `State.counter` by 3; reached via `State.triple_count`."""
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        """Render the token, counter, iterations input and one button per handler.

        The tests locate every element by the ids assigned here.
        """
        return rx.vstack(
            rx.input(
                id="token", value=State.router.session.client_token, is_read_only=True
            ),
            rx.hstack(
                rx.heading(State.counter, id="counter"),
                rx.text(State.counter_async_cv, size="1", id="counter-async-cv"),
            ),
            rx.input(
                id="iterations",
                placeholder="Iterations",
                value=State.iterations.to_string(),
                on_change=State.set_iterations,
            ),
            rx.button(
                "Delayed Increment",
                on_click=State.handle_event,
                id="delayed-increment",
            ),
            rx.button(
                "Yield Increment",
                on_click=State.handle_event_yield_only,
                id="yield-increment",
            ),
            rx.button("Increment 1", on_click=State.increment, id="increment"),
            rx.button(
                "Blocking Pause",
                on_click=State.blocking_pause,
                id="blocking-pause",
            ),
            rx.button(
                "Non-Blocking Pause",
                on_click=State.non_blocking_pause,
                id="non-blocking-pause",
            ),
            rx.button(
                "Racy Increment (x4)",
                on_click=State.handle_racy_event,
                id="racy-increment",
            ),
            rx.button(
                "Nested Async with Self",
                on_click=State.nested_async_with_self,
                id="nested-async-with-self",
            ),
            rx.button(
                "Increment from OtherState",
                on_click=OtherState.get_other_state,
                id="increment-from-other-state",
            ),
            rx.button(
                "Yield in Async with Self",
                on_click=State.yield_in_async_with_self,
                id="yield-in-async-with-self",
            ),
            rx.button("Reset", on_click=State.reset_counter, id="reset"),
        )

    app = rx.App()
    app.add_page(index)
  166. @pytest.fixture(scope="module")
  167. def background_task(
  168. tmp_path_factory,
  169. ) -> Generator[AppHarness, None, None]:
  170. """Start BackgroundTask app at tmp_path via AppHarness.
  171. Args:
  172. tmp_path_factory: pytest tmp_path_factory fixture
  173. Yields:
  174. running AppHarness instance
  175. """
  176. with AppHarness.create(
  177. root=tmp_path_factory.mktemp("background_task"),
  178. app_source=BackgroundTask,
  179. ) as harness:
  180. yield harness
  181. @pytest.fixture
  182. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  183. """Get an instance of the browser open to the background_task app.
  184. Args:
  185. background_task: harness for BackgroundTask app
  186. Yields:
  187. WebDriver instance.
  188. """
  189. assert background_task.app_instance is not None, "app is not running"
  190. driver = background_task.frontend()
  191. try:
  192. yield driver
  193. finally:
  194. driver.quit()
  195. @pytest.fixture()
  196. def token(background_task: AppHarness, driver: WebDriver) -> str:
  197. """Get a function that returns the active token.
  198. Args:
  199. background_task: harness for BackgroundTask app.
  200. driver: WebDriver instance.
  201. Returns:
  202. The token for the connected client
  203. """
  204. assert background_task.app_instance is not None
  205. token_input = driver.find_element(By.ID, "token")
  206. assert token_input
  207. # wait for the backend connection to send the token
  208. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  209. assert token is not None
  210. return token
  211. def test_background_task(
  212. background_task: AppHarness,
  213. driver: WebDriver,
  214. token: str,
  215. ):
  216. """Test that background tasks work as expected.
  217. Args:
  218. background_task: harness for BackgroundTask app.
  219. driver: WebDriver instance.
  220. token: The token for the connected client.
  221. """
  222. assert background_task.app_instance is not None
  223. # get a reference to all buttons
  224. delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
  225. yield_increment_button = driver.find_element(By.ID, "yield-increment")
  226. increment_button = driver.find_element(By.ID, "increment")
  227. blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
  228. non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
  229. racy_increment_button = driver.find_element(By.ID, "racy-increment")
  230. driver.find_element(By.ID, "reset")
  231. # get a reference to the counter
  232. counter = driver.find_element(By.ID, "counter")
  233. counter_async_cv = driver.find_element(By.ID, "counter-async-cv")
  234. # get a reference to the iterations input
  235. iterations_input = driver.find_element(By.ID, "iterations")
  236. # kick off background tasks
  237. iterations_input.clear()
  238. iterations_input.send_keys("50")
  239. delayed_increment_button.click()
  240. blocking_pause_button.click()
  241. delayed_increment_button.click()
  242. for _ in range(10):
  243. increment_button.click()
  244. blocking_pause_button.click()
  245. delayed_increment_button.click()
  246. delayed_increment_button.click()
  247. yield_increment_button.click()
  248. racy_increment_button.click()
  249. non_blocking_pause_button.click()
  250. yield_increment_button.click()
  251. blocking_pause_button.click()
  252. yield_increment_button.click()
  253. for _ in range(10):
  254. increment_button.click()
  255. yield_increment_button.click()
  256. blocking_pause_button.click()
  257. assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
  258. assert background_task._poll_for(lambda: counter_async_cv.text == "620", timeout=40)
  259. # all tasks should have exited and cleaned up
  260. assert background_task._poll_for(
  261. lambda: not background_task.app_instance._background_tasks # pyright: ignore [reportOptionalMemberAccess]
  262. )
  263. def test_nested_async_with_self(
  264. background_task: AppHarness,
  265. driver: WebDriver,
  266. token: str,
  267. ):
  268. """Test that nested async with self in the same coroutine raises Exception.
  269. Args:
  270. background_task: harness for BackgroundTask app.
  271. driver: WebDriver instance.
  272. token: The token for the connected client.
  273. """
  274. assert background_task.app_instance is not None
  275. # get a reference to all buttons
  276. nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
  277. increment_button = driver.find_element(By.ID, "increment")
  278. # get a reference to the counter
  279. counter = driver.find_element(By.ID, "counter")
  280. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  281. nested_async_with_self_button.click()
  282. assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
  283. increment_button.click()
  284. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  285. def test_get_state(
  286. background_task: AppHarness,
  287. driver: WebDriver,
  288. token: str,
  289. ):
  290. """Test that get_state returns a state bound to the correct StateProxy.
  291. Args:
  292. background_task: harness for BackgroundTask app.
  293. driver: WebDriver instance.
  294. token: The token for the connected client.
  295. """
  296. assert background_task.app_instance is not None
  297. # get a reference to all buttons
  298. other_state_button = driver.find_element(By.ID, "increment-from-other-state")
  299. increment_button = driver.find_element(By.ID, "increment")
  300. # get a reference to the counter
  301. counter = driver.find_element(By.ID, "counter")
  302. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  303. other_state_button.click()
  304. assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
  305. increment_button.click()
  306. assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
  307. def test_yield_in_async_with_self(
  308. background_task: AppHarness,
  309. driver: WebDriver,
  310. token: str,
  311. ):
  312. """Test that yielding inside async with self does not disable mutability.
  313. Args:
  314. background_task: harness for BackgroundTask app.
  315. driver: WebDriver instance.
  316. token: The token for the connected client.
  317. """
  318. assert background_task.app_instance is not None
  319. # get a reference to all buttons
  320. yield_in_async_with_self_button = driver.find_element(
  321. By.ID, "yield-in-async-with-self"
  322. )
  323. # get a reference to the counter
  324. counter = driver.find_element(By.ID, "counter")
  325. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  326. yield_in_async_with_self_button.click()
  327. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)