test_background_task.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. """Test @rx.background task functionality."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
    """Test that background tasks work as expected.

    App source for the tests in this module: it is passed as ``app_source`` to
    ``AppHarness.create`` by the ``background_task`` fixture. Imports are done
    inside the function body -- NOTE(review): presumably because AppHarness
    turns this function's body into the generated app module; confirm.
    """
    import asyncio

    import pytest

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        # Value rendered in the #counter heading and asserted on by the tests.
        counter: int = 0
        # Incremented once per task start; not read anywhere in this file.
        _task_id: int = 0
        # Number of increments each task performs; bound to the #iterations input.
        iterations: int = 10

        @rx.background
        async def handle_event(self):
            # Adds `iterations` to counter, re-entering `async with self` for every
            # single increment and sleeping outside of it.
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.background
        async def handle_event_yield_only(self):
            # Adds `iterations` to counter without mutating it here: each loop step
            # yields an event (alternating between the background
            # increment_arbitrary(1) and the regular increment) that does the +1.
            async with self:
                self._task_id += 1
            for ix in range(int(self.iterations)):
                if ix % 2 == 0:
                    yield State.increment_arbitrary(1)  # type: ignore
                else:
                    yield State.increment()  # type: ignore
                await asyncio.sleep(0.005)

        def increment(self):
            # Regular (non-background) handler: counter += 1.
            self.counter += 1

        @rx.background
        async def increment_arbitrary(self, amount: int):
            async with self:
                self.counter += int(amount)

        def reset_counter(self):
            self.counter = 0

        async def blocking_pause(self):
            # Regular handler that only sleeps; it does not touch counter.
            await asyncio.sleep(0.02)

        @rx.background
        async def non_blocking_pause(self):
            # Background handler that only sleeps; it does not touch counter.
            await asyncio.sleep(0.02)

        async def racy_task(self):
            # Same body as handle_event but undecorated; run concurrently by
            # handle_racy_event to check that interleaved updates are not lost.
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.background
        async def handle_racy_event(self):
            # Four concurrent racy_task coroutines: net +4 * iterations.
            await asyncio.gather(
                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
            )

        @rx.background
        async def nested_async_with_self(self):
            async with self:
                self.counter += 1
                # Re-entering `async with self` while already inside it must raise
                # ImmutableStateError, so the inner increment never applies (net +1).
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            # Delegates to ThirdState, which multiplies State.counter by 3.
            third_state = await self.get_state(ThirdState)
            await third_state._triple_count()

        @rx.background
        async def yield_in_async_with_self(self):
            # The second increment after the bare `yield` must still be allowed,
            # i.e. yielding inside `async with self` keeps the state mutable (net +2).
            async with self:
                self.counter += 1
                yield
                self.counter += 1

    class OtherState(rx.State):
        @rx.background
        async def get_other_state(self):
            # Starting from 0: +1, x3 -> 3 while holding this state's lock.
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()
            # Outside any `async with`, both mutation paths must be rejected.
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1
            # Locking the fetched state directly re-enables mutation: +1, x3 -> 12.
            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        async def _triple_count(self):
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        # Every element gets an id so the Selenium tests can locate it.
        return rx.vstack(
            rx.input(
                id="token", value=State.router.session.client_token, is_read_only=True
            ),
            rx.heading(State.counter, id="counter"),
            rx.input(
                id="iterations",
                placeholder="Iterations",
                value=State.iterations.to_string(),  # type: ignore
                on_change=State.set_iterations,  # type: ignore
            ),
            rx.button(
                "Delayed Increment",
                on_click=State.handle_event,
                id="delayed-increment",
            ),
            rx.button(
                "Yield Increment",
                on_click=State.handle_event_yield_only,
                id="yield-increment",
            ),
            rx.button("Increment 1", on_click=State.increment, id="increment"),
            rx.button(
                "Blocking Pause",
                on_click=State.blocking_pause,
                id="blocking-pause",
            ),
            rx.button(
                "Non-Blocking Pause",
                on_click=State.non_blocking_pause,
                id="non-blocking-pause",
            ),
            rx.button(
                "Racy Increment (x4)",
                on_click=State.handle_racy_event,
                id="racy-increment",
            ),
            rx.button(
                "Nested Async with Self",
                on_click=State.nested_async_with_self,
                id="nested-async-with-self",
            ),
            rx.button(
                "Increment from OtherState",
                on_click=OtherState.get_other_state,
                id="increment-from-other-state",
            ),
            rx.button(
                "Yield in Async with Self",
                on_click=State.yield_in_async_with_self,
                id="yield-in-async-with-self",
            ),
            rx.button("Reset", on_click=State.reset_counter, id="reset"),
        )

    app = rx.App(state=rx.State)
    app.add_page(index)
  150. @pytest.fixture(scope="module")
  151. def background_task(
  152. tmp_path_factory,
  153. ) -> Generator[AppHarness, None, None]:
  154. """Start BackgroundTask app at tmp_path via AppHarness.
  155. Args:
  156. tmp_path_factory: pytest tmp_path_factory fixture
  157. Yields:
  158. running AppHarness instance
  159. """
  160. with AppHarness.create(
  161. root=tmp_path_factory.mktemp(f"background_task"),
  162. app_source=BackgroundTask, # type: ignore
  163. ) as harness:
  164. yield harness
  165. @pytest.fixture
  166. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  167. """Get an instance of the browser open to the background_task app.
  168. Args:
  169. background_task: harness for BackgroundTask app
  170. Yields:
  171. WebDriver instance.
  172. """
  173. assert background_task.app_instance is not None, "app is not running"
  174. driver = background_task.frontend()
  175. try:
  176. yield driver
  177. finally:
  178. driver.quit()
  179. @pytest.fixture()
  180. def token(background_task: AppHarness, driver: WebDriver) -> str:
  181. """Get a function that returns the active token.
  182. Args:
  183. background_task: harness for BackgroundTask app.
  184. driver: WebDriver instance.
  185. Returns:
  186. The token for the connected client
  187. """
  188. assert background_task.app_instance is not None
  189. token_input = driver.find_element(By.ID, "token")
  190. assert token_input
  191. # wait for the backend connection to send the token
  192. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  193. assert token is not None
  194. return token
def test_background_task(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Test that background tasks work as expected.

    Interleaves background handlers (delayed, yield-only, racy), regular
    increments and blocking/non-blocking pauses, then checks that no increment
    was lost and that every background task finished.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client. Unused in the body; requested
            so the test waits until the backend connection is established.
    """
    assert background_task.app_instance is not None

    # get a reference to all buttons
    delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
    yield_increment_button = driver.find_element(By.ID, "yield-increment")
    increment_button = driver.find_element(By.ID, "increment")
    blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
    non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
    racy_increment_button = driver.find_element(By.ID, "racy-increment")
    # Result unused: this only checks that the Reset button is rendered.
    driver.find_element(By.ID, "reset")

    # get a reference to the counter
    counter = driver.find_element(By.ID, "counter")

    # get a reference to the iterations input
    iterations_input = driver.find_element(By.ID, "iterations")

    # kick off background tasks
    # Expected final counter with iterations = 50 (pauses add nothing):
    #   4 x delayed-increment        -> 4 * 50 = 200
    #   4 x yield-increment          -> 4 * 50 = 200
    #   1 x racy-increment (4 tasks) -> 4 * 50 = 200
    #   20 x increment               ->           20
    #   total                                   = 620
    iterations_input.clear()
    iterations_input.send_keys("50")
    delayed_increment_button.click()
    blocking_pause_button.click()
    delayed_increment_button.click()
    for _ in range(10):
        increment_button.click()
    blocking_pause_button.click()
    delayed_increment_button.click()
    delayed_increment_button.click()
    yield_increment_button.click()
    racy_increment_button.click()
    non_blocking_pause_button.click()
    yield_increment_button.click()
    blocking_pause_button.click()
    yield_increment_button.click()
    for _ in range(10):
        increment_button.click()
    yield_increment_button.click()
    blocking_pause_button.click()
    # Generous timeout: all the tasks above must drain before the total is reached.
    assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
    # all tasks should have exited and cleaned up
    assert background_task._poll_for(
        lambda: not background_task.app_instance.background_tasks  # type: ignore
    )
  245. def test_nested_async_with_self(
  246. background_task: AppHarness,
  247. driver: WebDriver,
  248. token: str,
  249. ):
  250. """Test that nested async with self in the same coroutine raises Exception.
  251. Args:
  252. background_task: harness for BackgroundTask app.
  253. driver: WebDriver instance.
  254. token: The token for the connected client.
  255. """
  256. assert background_task.app_instance is not None
  257. # get a reference to all buttons
  258. nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
  259. increment_button = driver.find_element(By.ID, "increment")
  260. # get a reference to the counter
  261. counter = driver.find_element(By.ID, "counter")
  262. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  263. nested_async_with_self_button.click()
  264. assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
  265. increment_button.click()
  266. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  267. def test_get_state(
  268. background_task: AppHarness,
  269. driver: WebDriver,
  270. token: str,
  271. ):
  272. """Test that get_state returns a state bound to the correct StateProxy.
  273. Args:
  274. background_task: harness for BackgroundTask app.
  275. driver: WebDriver instance.
  276. token: The token for the connected client.
  277. """
  278. assert background_task.app_instance is not None
  279. # get a reference to all buttons
  280. other_state_button = driver.find_element(By.ID, "increment-from-other-state")
  281. increment_button = driver.find_element(By.ID, "increment")
  282. # get a reference to the counter
  283. counter = driver.find_element(By.ID, "counter")
  284. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  285. other_state_button.click()
  286. assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
  287. increment_button.click()
  288. assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
  289. def test_yield_in_async_with_self(
  290. background_task: AppHarness,
  291. driver: WebDriver,
  292. token: str,
  293. ):
  294. """Test that yielding inside async with self does not disable mutability.
  295. Args:
  296. background_task: harness for BackgroundTask app.
  297. driver: WebDriver instance.
  298. token: The token for the connected client.
  299. """
  300. assert background_task.app_instance is not None
  301. # get a reference to all buttons
  302. yield_in_async_with_self_button = driver.find_element(
  303. By.ID, "yield-in-async-with-self"
  304. )
  305. # get a reference to the counter
  306. counter = driver.find_element(By.ID, "counter")
  307. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  308. yield_in_async_with_self_button.click()
  309. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)