test_background_task.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. """Test @rx.event(background=True) task functionality."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
    """Define the Reflex app exercised by the background task tests.

    Passed as ``app_source`` to ``AppHarness.create`` by the
    ``background_task`` fixture. Declares three states, an index page with one
    button per scenario, and the ``rx.App`` that serves the page.
    """
    # NOTE(review): imports are local to this function, presumably so the body
    # is self-contained when the harness uses it as the app source -- confirm.
    import asyncio

    import pytest

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        """Main state: holds the counter that every test scenario mutates."""

        # Value rendered in the "counter" heading and polled by the tests.
        counter: int = 0
        # Incremented once per started task; never read in this file.
        _task_id: int = 0
        # Loop count used by the looping handlers; initialised via rx.field(10).
        iterations: rx.Field[int] = rx.field(10)

        @rx.event
        def set_iterations(self, value: str):
            """Store the text of the iterations input as an int."""
            self.iterations = int(value)

        @rx.event(background=True)
        async def handle_event(self):
            """Increment the counter ``iterations`` times, one step per loop."""
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                # Enter the state context only for the mutation itself, then
                # leave it again before sleeping.
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_event_yield_only(self):
            """Increment the counter only by yielding other events."""
            async with self:
                self._task_id += 1
            for ix in range(int(self.iterations)):
                # Alternate between a background handler (even steps) and a
                # regular handler (odd steps); both add 1 to the counter.
                if ix % 2 == 0:
                    yield State.increment_arbitrary(1)
                else:
                    yield State.increment()
                await asyncio.sleep(0.005)

        @rx.event
        def increment(self):
            """Add 1 to the counter from a regular (non-background) event."""
            self.counter += 1

        @rx.event(background=True)
        async def increment_arbitrary(self, amount: int):
            """Add ``amount`` to the counter from a background event."""
            async with self:
                self.counter += int(amount)

        @rx.event
        def reset_counter(self):
            """Set the counter back to 0."""
            self.counter = 0

        @rx.event
        async def blocking_pause(self):
            """Sleep for 20 ms inside a regular event handler."""
            await asyncio.sleep(0.02)

        @rx.event(background=True)
        async def non_blocking_pause(self):
            """Sleep for 20 ms inside a background event handler."""
            await asyncio.sleep(0.02)

        async def racy_task(self):
            """Same loop as ``handle_event``, but a plain coroutine.

            Not decorated as an event; it is awaited from
            ``handle_racy_event`` only.
            """
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_racy_event(self):
            """Run four ``racy_task`` coroutines concurrently."""
            await asyncio.gather(
                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
            )

        @rx.event(background=True)
        async def nested_async_with_self(self):
            """Check that re-entering ``async with self`` is rejected."""
            async with self:
                self.counter += 1
                # The inner context is expected to raise, so the second
                # increment should never run (the test expects a counter of 1).
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            """Multiply the counter by 3 by going through ``ThirdState``."""
            third_state = await self.get_state(ThirdState)
            await third_state._triple_count()

        @rx.event(background=True)
        async def yield_in_async_with_self(self):
            """Increment, yield, and increment again within one context."""
            async with self:
                self.counter += 1
                yield
                # Still inside the same ``async with self`` after the yield;
                # the test expects this second increment to succeed.
                self.counter += 1

    class OtherState(rx.State):
        """Second state used to reach ``State`` through ``get_state``."""

        @rx.event(background=True)
        async def get_other_state(self):
            """Mutate ``State`` from another state's background task.

            The test expects the counter to end at 12:
            (0 + 1) * 3 = 3, then (3 + 1) * 3 = 12.
            """
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()
            # Outside any context, both the indirect and the direct mutation
            # are expected to raise ImmutableStateError.
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1
            # Entering the context of the fetched state allows mutation again.
            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        """Third state whose only job is to triple ``State.counter``."""

        async def _triple_count(self):
            """Fetch ``State`` and multiply its counter by 3."""
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        """Build the page: token input, counter heading, and one button per scenario."""
        return rx.vstack(
            # Read-only input showing the client token; polled by the token fixture.
            rx.input(
                id="token", value=State.router.session.client_token, is_read_only=True
            ),
            rx.heading(State.counter, id="counter"),
            rx.input(
                id="iterations",
                placeholder="Iterations",
                value=State.iterations.to_string(),
                on_change=State.set_iterations,
            ),
            rx.button(
                "Delayed Increment",
                on_click=State.handle_event,
                id="delayed-increment",
            ),
            rx.button(
                "Yield Increment",
                on_click=State.handle_event_yield_only,
                id="yield-increment",
            ),
            rx.button("Increment 1", on_click=State.increment, id="increment"),
            rx.button(
                "Blocking Pause",
                on_click=State.blocking_pause,
                id="blocking-pause",
            ),
            rx.button(
                "Non-Blocking Pause",
                on_click=State.non_blocking_pause,
                id="non-blocking-pause",
            ),
            rx.button(
                "Racy Increment (x4)",
                on_click=State.handle_racy_event,
                id="racy-increment",
            ),
            rx.button(
                "Nested Async with Self",
                on_click=State.nested_async_with_self,
                id="nested-async-with-self",
            ),
            rx.button(
                "Increment from OtherState",
                on_click=OtherState.get_other_state,
                id="increment-from-other-state",
            ),
            rx.button(
                "Yield in Async with Self",
                on_click=State.yield_in_async_with_self,
                id="yield-in-async-with-self",
            ),
            rx.button("Reset", on_click=State.reset_counter, id="reset"),
        )

    # NOTE(review): `_state=rx.State` is a private-looking argument; its exact
    # effect is not visible in this file -- confirm against rx.App.
    app = rx.App(_state=rx.State)
    app.add_page(index)
  156. @pytest.fixture(scope="module")
  157. def background_task(
  158. tmp_path_factory,
  159. ) -> Generator[AppHarness, None, None]:
  160. """Start BackgroundTask app at tmp_path via AppHarness.
  161. Args:
  162. tmp_path_factory: pytest tmp_path_factory fixture
  163. Yields:
  164. running AppHarness instance
  165. """
  166. with AppHarness.create(
  167. root=tmp_path_factory.mktemp("background_task"),
  168. app_source=BackgroundTask,
  169. ) as harness:
  170. yield harness
  171. @pytest.fixture
  172. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  173. """Get an instance of the browser open to the background_task app.
  174. Args:
  175. background_task: harness for BackgroundTask app
  176. Yields:
  177. WebDriver instance.
  178. """
  179. assert background_task.app_instance is not None, "app is not running"
  180. driver = background_task.frontend()
  181. try:
  182. yield driver
  183. finally:
  184. driver.quit()
  185. @pytest.fixture()
  186. def token(background_task: AppHarness, driver: WebDriver) -> str:
  187. """Get a function that returns the active token.
  188. Args:
  189. background_task: harness for BackgroundTask app.
  190. driver: WebDriver instance.
  191. Returns:
  192. The token for the connected client
  193. """
  194. assert background_task.app_instance is not None
  195. token_input = driver.find_element(By.ID, "token")
  196. assert token_input
  197. # wait for the backend connection to send the token
  198. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  199. assert token is not None
  200. return token
  201. def test_background_task(
  202. background_task: AppHarness,
  203. driver: WebDriver,
  204. token: str,
  205. ):
  206. """Test that background tasks work as expected.
  207. Args:
  208. background_task: harness for BackgroundTask app.
  209. driver: WebDriver instance.
  210. token: The token for the connected client.
  211. """
  212. assert background_task.app_instance is not None
  213. # get a reference to all buttons
  214. delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
  215. yield_increment_button = driver.find_element(By.ID, "yield-increment")
  216. increment_button = driver.find_element(By.ID, "increment")
  217. blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
  218. non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
  219. racy_increment_button = driver.find_element(By.ID, "racy-increment")
  220. driver.find_element(By.ID, "reset")
  221. # get a reference to the counter
  222. counter = driver.find_element(By.ID, "counter")
  223. # get a reference to the iterations input
  224. iterations_input = driver.find_element(By.ID, "iterations")
  225. # kick off background tasks
  226. iterations_input.clear()
  227. iterations_input.send_keys("50")
  228. delayed_increment_button.click()
  229. blocking_pause_button.click()
  230. delayed_increment_button.click()
  231. for _ in range(10):
  232. increment_button.click()
  233. blocking_pause_button.click()
  234. delayed_increment_button.click()
  235. delayed_increment_button.click()
  236. yield_increment_button.click()
  237. racy_increment_button.click()
  238. non_blocking_pause_button.click()
  239. yield_increment_button.click()
  240. blocking_pause_button.click()
  241. yield_increment_button.click()
  242. for _ in range(10):
  243. increment_button.click()
  244. yield_increment_button.click()
  245. blocking_pause_button.click()
  246. assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
  247. # all tasks should have exited and cleaned up
  248. assert background_task._poll_for(
  249. lambda: not background_task.app_instance._background_tasks # pyright: ignore [reportOptionalMemberAccess]
  250. )
  251. def test_nested_async_with_self(
  252. background_task: AppHarness,
  253. driver: WebDriver,
  254. token: str,
  255. ):
  256. """Test that nested async with self in the same coroutine raises Exception.
  257. Args:
  258. background_task: harness for BackgroundTask app.
  259. driver: WebDriver instance.
  260. token: The token for the connected client.
  261. """
  262. assert background_task.app_instance is not None
  263. # get a reference to all buttons
  264. nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
  265. increment_button = driver.find_element(By.ID, "increment")
  266. # get a reference to the counter
  267. counter = driver.find_element(By.ID, "counter")
  268. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  269. nested_async_with_self_button.click()
  270. assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
  271. increment_button.click()
  272. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  273. def test_get_state(
  274. background_task: AppHarness,
  275. driver: WebDriver,
  276. token: str,
  277. ):
  278. """Test that get_state returns a state bound to the correct StateProxy.
  279. Args:
  280. background_task: harness for BackgroundTask app.
  281. driver: WebDriver instance.
  282. token: The token for the connected client.
  283. """
  284. assert background_task.app_instance is not None
  285. # get a reference to all buttons
  286. other_state_button = driver.find_element(By.ID, "increment-from-other-state")
  287. increment_button = driver.find_element(By.ID, "increment")
  288. # get a reference to the counter
  289. counter = driver.find_element(By.ID, "counter")
  290. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  291. other_state_button.click()
  292. assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
  293. increment_button.click()
  294. assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
  295. def test_yield_in_async_with_self(
  296. background_task: AppHarness,
  297. driver: WebDriver,
  298. token: str,
  299. ):
  300. """Test that yielding inside async with self does not disable mutability.
  301. Args:
  302. background_task: harness for BackgroundTask app.
  303. driver: WebDriver instance.
  304. token: The token for the connected client.
  305. """
  306. assert background_task.app_instance is not None
  307. # get a reference to all buttons
  308. yield_in_async_with_self_button = driver.find_element(
  309. By.ID, "yield-in-async-with-self"
  310. )
  311. # get a reference to the counter
  312. counter = driver.find_element(By.ID, "counter")
  313. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  314. yield_in_async_with_self_button.click()
  315. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)