test_background_task.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. """Test @rx.background task functionality."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
  6. def BackgroundTask():
  7. """Test that background tasks work as expected."""
  8. import asyncio
  9. import pytest
  10. import reflex as rx
  11. from reflex.state import ImmutableStateError
  12. class State(rx.State):
  13. counter: int = 0
  14. _task_id: int = 0
  15. iterations: int = 10
  16. @rx.background
  17. @rx.event
  18. async def handle_event(self):
  19. async with self:
  20. self._task_id += 1
  21. for _ix in range(int(self.iterations)):
  22. async with self:
  23. self.counter += 1
  24. await asyncio.sleep(0.005)
  25. @rx.background
  26. @rx.event
  27. async def handle_event_yield_only(self):
  28. async with self:
  29. self._task_id += 1
  30. for ix in range(int(self.iterations)):
  31. if ix % 2 == 0:
  32. yield State.increment_arbitrary(1) # type: ignore
  33. else:
  34. yield State.increment() # type: ignore
  35. await asyncio.sleep(0.005)
  36. @rx.event
  37. def increment(self):
  38. self.counter += 1
  39. @rx.background
  40. async def increment_arbitrary(self, amount: int):
  41. async with self:
  42. self.counter += int(amount)
  43. @rx.event
  44. def reset_counter(self):
  45. self.counter = 0
  46. @rx.event
  47. async def blocking_pause(self):
  48. await asyncio.sleep(0.02)
  49. @rx.background
  50. @rx.event
  51. async def non_blocking_pause(self):
  52. await asyncio.sleep(0.02)
  53. async def racy_task(self):
  54. async with self:
  55. self._task_id += 1
  56. for _ix in range(int(self.iterations)):
  57. async with self:
  58. self.counter += 1
  59. await asyncio.sleep(0.005)
  60. @rx.background
  61. @rx.event
  62. async def handle_racy_event(self):
  63. await asyncio.gather(
  64. self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
  65. )
  66. @rx.background
  67. @rx.event
  68. async def nested_async_with_self(self):
  69. async with self:
  70. self.counter += 1
  71. with pytest.raises(ImmutableStateError):
  72. async with self:
  73. self.counter += 1
  74. async def triple_count(self):
  75. third_state = await self.get_state(ThirdState)
  76. await third_state._triple_count()
  77. @rx.background
  78. @rx.event
  79. async def yield_in_async_with_self(self):
  80. async with self:
  81. self.counter += 1
  82. yield
  83. self.counter += 1
  84. class OtherState(rx.State):
  85. @rx.background
  86. @rx.event
  87. async def get_other_state(self):
  88. async with self:
  89. state = await self.get_state(State)
  90. state.counter += 1
  91. await state.triple_count()
  92. with pytest.raises(ImmutableStateError):
  93. await state.triple_count()
  94. with pytest.raises(ImmutableStateError):
  95. state.counter += 1
  96. async with state:
  97. state.counter += 1
  98. await state.triple_count()
  99. class ThirdState(rx.State):
  100. async def _triple_count(self):
  101. state = await self.get_state(State)
  102. state.counter *= 3
  103. def index() -> rx.Component:
  104. return rx.vstack(
  105. rx.input(
  106. id="token", value=State.router.session.client_token, is_read_only=True
  107. ),
  108. rx.heading(State.counter, id="counter"),
  109. rx.input(
  110. id="iterations",
  111. placeholder="Iterations",
  112. value=State.iterations.to_string(), # type: ignore
  113. on_change=State.set_iterations, # type: ignore
  114. ),
  115. rx.button(
  116. "Delayed Increment",
  117. on_click=State.handle_event,
  118. id="delayed-increment",
  119. ),
  120. rx.button(
  121. "Yield Increment",
  122. on_click=State.handle_event_yield_only,
  123. id="yield-increment",
  124. ),
  125. rx.button("Increment 1", on_click=State.increment, id="increment"),
  126. rx.button(
  127. "Blocking Pause",
  128. on_click=State.blocking_pause,
  129. id="blocking-pause",
  130. ),
  131. rx.button(
  132. "Non-Blocking Pause",
  133. on_click=State.non_blocking_pause,
  134. id="non-blocking-pause",
  135. ),
  136. rx.button(
  137. "Racy Increment (x4)",
  138. on_click=State.handle_racy_event,
  139. id="racy-increment",
  140. ),
  141. rx.button(
  142. "Nested Async with Self",
  143. on_click=State.nested_async_with_self,
  144. id="nested-async-with-self",
  145. ),
  146. rx.button(
  147. "Increment from OtherState",
  148. on_click=OtherState.get_other_state,
  149. id="increment-from-other-state",
  150. ),
  151. rx.button(
  152. "Yield in Async with Self",
  153. on_click=State.yield_in_async_with_self,
  154. id="yield-in-async-with-self",
  155. ),
  156. rx.button("Reset", on_click=State.reset_counter, id="reset"),
  157. )
  158. app = rx.App(state=rx.State)
  159. app.add_page(index)
  160. @pytest.fixture(scope="module")
  161. def background_task(
  162. tmp_path_factory,
  163. ) -> Generator[AppHarness, None, None]:
  164. """Start BackgroundTask app at tmp_path via AppHarness.
  165. Args:
  166. tmp_path_factory: pytest tmp_path_factory fixture
  167. Yields:
  168. running AppHarness instance
  169. """
  170. with AppHarness.create(
  171. root=tmp_path_factory.mktemp(f"background_task"),
  172. app_source=BackgroundTask, # type: ignore
  173. ) as harness:
  174. yield harness
  175. @pytest.fixture
  176. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  177. """Get an instance of the browser open to the background_task app.
  178. Args:
  179. background_task: harness for BackgroundTask app
  180. Yields:
  181. WebDriver instance.
  182. """
  183. assert background_task.app_instance is not None, "app is not running"
  184. driver = background_task.frontend()
  185. try:
  186. yield driver
  187. finally:
  188. driver.quit()
  189. @pytest.fixture()
  190. def token(background_task: AppHarness, driver: WebDriver) -> str:
  191. """Get a function that returns the active token.
  192. Args:
  193. background_task: harness for BackgroundTask app.
  194. driver: WebDriver instance.
  195. Returns:
  196. The token for the connected client
  197. """
  198. assert background_task.app_instance is not None
  199. token_input = driver.find_element(By.ID, "token")
  200. assert token_input
  201. # wait for the backend connection to send the token
  202. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  203. assert token is not None
  204. return token
def test_background_task(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Test that background tasks work as expected.

    Fires a fixed, order-dependent mix of background, blocking and regular
    events. The test then checks that every increment landed and that all
    background tasks were cleaned up.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client (requested so the backend
            connection is up before clicking; not used directly).
    """
    assert background_task.app_instance is not None

    # get a reference to all buttons
    delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
    yield_increment_button = driver.find_element(By.ID, "yield-increment")
    increment_button = driver.find_element(By.ID, "increment")
    blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
    non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
    racy_increment_button = driver.find_element(By.ID, "racy-increment")
    # only checks that the reset button is rendered; it is never clicked here
    driver.find_element(By.ID, "reset")

    # get a reference to the counter
    counter = driver.find_element(By.ID, "counter")

    # get a reference to the iterations input
    iterations_input = driver.find_element(By.ID, "iterations")

    # kick off background tasks
    # replace the default of 10 iterations with 50
    iterations_input.clear()
    iterations_input.send_keys("50")
    delayed_increment_button.click()
    blocking_pause_button.click()
    delayed_increment_button.click()
    for _ in range(10):
        increment_button.click()
    blocking_pause_button.click()
    delayed_increment_button.click()
    delayed_increment_button.click()
    yield_increment_button.click()
    racy_increment_button.click()
    non_blocking_pause_button.click()
    yield_increment_button.click()
    blocking_pause_button.click()
    yield_increment_button.click()
    for _ in range(10):
        increment_button.click()
    yield_increment_button.click()
    blocking_pause_button.click()

    # Expected total with 50 iterations:
    #   4 delayed increments   * 50     = 200
    #   4 yield increments     * 50     = 200
    #   1 racy increment (x4)  * 4 * 50 = 200
    #   20 single increments   * 1      =  20
    #                                     ---
    #                                     620
    assert background_task._poll_for(lambda: counter.text == "620", timeout=40)

    # all tasks should have exited and cleaned up
    assert background_task._poll_for(
        lambda: not background_task.app_instance.background_tasks  # type: ignore
    )
  255. def test_nested_async_with_self(
  256. background_task: AppHarness,
  257. driver: WebDriver,
  258. token: str,
  259. ):
  260. """Test that nested async with self in the same coroutine raises Exception.
  261. Args:
  262. background_task: harness for BackgroundTask app.
  263. driver: WebDriver instance.
  264. token: The token for the connected client.
  265. """
  266. assert background_task.app_instance is not None
  267. # get a reference to all buttons
  268. nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
  269. increment_button = driver.find_element(By.ID, "increment")
  270. # get a reference to the counter
  271. counter = driver.find_element(By.ID, "counter")
  272. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  273. nested_async_with_self_button.click()
  274. assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
  275. increment_button.click()
  276. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  277. def test_get_state(
  278. background_task: AppHarness,
  279. driver: WebDriver,
  280. token: str,
  281. ):
  282. """Test that get_state returns a state bound to the correct StateProxy.
  283. Args:
  284. background_task: harness for BackgroundTask app.
  285. driver: WebDriver instance.
  286. token: The token for the connected client.
  287. """
  288. assert background_task.app_instance is not None
  289. # get a reference to all buttons
  290. other_state_button = driver.find_element(By.ID, "increment-from-other-state")
  291. increment_button = driver.find_element(By.ID, "increment")
  292. # get a reference to the counter
  293. counter = driver.find_element(By.ID, "counter")
  294. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  295. other_state_button.click()
  296. assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
  297. increment_button.click()
  298. assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
  299. def test_yield_in_async_with_self(
  300. background_task: AppHarness,
  301. driver: WebDriver,
  302. token: str,
  303. ):
  304. """Test that yielding inside async with self does not disable mutability.
  305. Args:
  306. background_task: harness for BackgroundTask app.
  307. driver: WebDriver instance.
  308. token: The token for the connected client.
  309. """
  310. assert background_task.app_instance is not None
  311. # get a reference to all buttons
  312. yield_in_async_with_self_button = driver.find_element(
  313. By.ID, "yield-in-async-with-self"
  314. )
  315. # get a reference to the counter
  316. counter = driver.find_element(By.ID, "counter")
  317. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  318. yield_in_async_with_self_button.click()
  319. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)