1
0

test_background_task.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. """Test @rx.event(background=True) task functionality."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
    """Test that background tasks work as expected.

    Defines the Reflex app exercised by the tests in this module. The function
    is passed to ``AppHarness.create`` as ``app_source`` by the
    ``background_task`` fixture, which is presumably why every import the app
    needs is performed locally inside the body.
    """
    import asyncio

    import pytest

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        # Value shown in the "counter" heading; every test asserts on it.
        counter: int = 0
        # Bumped once at the start of each incrementing task; never rendered.
        _task_id: int = 0
        # Loop count for the incrementing tasks; bound to the "iterations" input.
        iterations: int = 10

        @rx.event(background=True)
        async def handle_event(self):
            # Background task: add `iterations` to `counter`, one write per loop.
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                # Each write is wrapped in its own `async with self`; mutating
                # outside such a block raises ImmutableStateError (see
                # OtherState.get_other_state below).
                async with self:
                    self.counter += 1
                # Sleep outside the `async with` block so other events can run.
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_event_yield_only(self):
            # Background task that never writes `counter` itself: each loop
            # yields another event instead, alternating between the background
            # `increment_arbitrary(1)` and the regular `increment` handler.
            async with self:
                self._task_id += 1
            for ix in range(int(self.iterations)):
                if ix % 2 == 0:
                    yield State.increment_arbitrary(1)  # type: ignore
                else:
                    yield State.increment()  # type: ignore
                await asyncio.sleep(0.005)

        @rx.event
        def increment(self):
            # Regular (non-background) handler: mutates state directly.
            self.counter += 1

        @rx.event(background=True)
        async def increment_arbitrary(self, amount: int):
            # Background handler: add `amount` to `counter` under `async with self`.
            async with self:
                self.counter += int(amount)

        @rx.event
        def reset_counter(self):
            # Set `counter` back to zero.
            self.counter = 0

        @rx.event
        async def blocking_pause(self):
            # Regular (non-background) async handler that only sleeps.
            # NOTE(review): presumably this delays the client's subsequent
            # regular events while it runs - confirm against Reflex's event queue.
            await asyncio.sleep(0.02)

        @rx.event(background=True)
        async def non_blocking_pause(self):
            # Same sleep as blocking_pause, but declared as a background task.
            await asyncio.sleep(0.02)

        async def racy_task(self):
            # Plain coroutine (not an event handler) with the same body as
            # handle_event; handle_racy_event runs several of them concurrently.
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_racy_event(self):
            # Run four racy_task coroutines concurrently on the same state; each
            # adds `iterations`, so no increment may be lost to the race.
            await asyncio.gather(
                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
            )

        @rx.event(background=True)
        async def nested_async_with_self(self):
            # Re-entering `async with self` while already inside it is expected
            # to raise ImmutableStateError, so only the first increment is
            # applied (the test expects the counter to read 1).
            async with self:
                self.counter += 1
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            # Not an event handler: fetch ThirdState and let it triple `counter`.
            third_state = await self.get_state(ThirdState)
            await third_state._triple_count()

        @rx.event(background=True)
        async def yield_in_async_with_self(self):
            # Yield in the middle of `async with self`; the increment after the
            # yield must still be allowed (the test expects the counter to read 2).
            async with self:
                self.counter += 1
                yield
                self.counter += 1

    class OtherState(rx.State):
        @rx.event(background=True)
        async def get_other_state(self):
            # State obtained via get_state from another state's background task
            # follows the caller's `async with` scope. Net effect on
            # State.counter starting from 0: ((0 + 1) * 3 + 1) * 3 == 12.
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()
            # Outside `async with self` the fetched state must be immutable...
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1
            # ...until it is entered directly with its own `async with`.
            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        async def _triple_count(self):
            # Multiply State.counter by 3 via get_state.
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        # Page layout; the element ids are what the selenium tests look up.
        return rx.vstack(
            # Read-only display of the client token; polled by the `token` fixture.
            rx.input(
                id="token", value=State.router.session.client_token, is_read_only=True
            ),
            rx.heading(State.counter, id="counter"),
            rx.input(
                id="iterations",
                placeholder="Iterations",
                value=State.iterations.to_string(),  # type: ignore
                # Presumably the setter Reflex generates for `iterations`.
                on_change=State.set_iterations,  # type: ignore
            ),
            rx.button(
                "Delayed Increment",
                on_click=State.handle_event,
                id="delayed-increment",
            ),
            rx.button(
                "Yield Increment",
                on_click=State.handle_event_yield_only,
                id="yield-increment",
            ),
            rx.button("Increment 1", on_click=State.increment, id="increment"),
            rx.button(
                "Blocking Pause",
                on_click=State.blocking_pause,
                id="blocking-pause",
            ),
            rx.button(
                "Non-Blocking Pause",
                on_click=State.non_blocking_pause,
                id="non-blocking-pause",
            ),
            rx.button(
                "Racy Increment (x4)",
                on_click=State.handle_racy_event,
                id="racy-increment",
            ),
            rx.button(
                "Nested Async with Self",
                on_click=State.nested_async_with_self,
                id="nested-async-with-self",
            ),
            rx.button(
                "Increment from OtherState",
                on_click=OtherState.get_other_state,
                id="increment-from-other-state",
            ),
            rx.button(
                "Yield in Async with Self",
                on_click=State.yield_in_async_with_self,
                id="yield-in-async-with-self",
            ),
            rx.button("Reset", on_click=State.reset_counter, id="reset"),
        )

    app = rx.App(state=rx.State)
    app.add_page(index)
  153. @pytest.fixture(scope="module")
  154. def background_task(
  155. tmp_path_factory,
  156. ) -> Generator[AppHarness, None, None]:
  157. """Start BackgroundTask app at tmp_path via AppHarness.
  158. Args:
  159. tmp_path_factory: pytest tmp_path_factory fixture
  160. Yields:
  161. running AppHarness instance
  162. """
  163. with AppHarness.create(
  164. root=tmp_path_factory.mktemp("background_task"),
  165. app_source=BackgroundTask,
  166. ) as harness:
  167. yield harness
  168. @pytest.fixture
  169. def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
  170. """Get an instance of the browser open to the background_task app.
  171. Args:
  172. background_task: harness for BackgroundTask app
  173. Yields:
  174. WebDriver instance.
  175. """
  176. assert background_task.app_instance is not None, "app is not running"
  177. driver = background_task.frontend()
  178. try:
  179. yield driver
  180. finally:
  181. driver.quit()
  182. @pytest.fixture()
  183. def token(background_task: AppHarness, driver: WebDriver) -> str:
  184. """Get a function that returns the active token.
  185. Args:
  186. background_task: harness for BackgroundTask app.
  187. driver: WebDriver instance.
  188. Returns:
  189. The token for the connected client
  190. """
  191. assert background_task.app_instance is not None
  192. token_input = driver.find_element(By.ID, "token")
  193. assert token_input
  194. # wait for the backend connection to send the token
  195. token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
  196. assert token is not None
  197. return token
  198. def test_background_task(
  199. background_task: AppHarness,
  200. driver: WebDriver,
  201. token: str,
  202. ):
  203. """Test that background tasks work as expected.
  204. Args:
  205. background_task: harness for BackgroundTask app.
  206. driver: WebDriver instance.
  207. token: The token for the connected client.
  208. """
  209. assert background_task.app_instance is not None
  210. # get a reference to all buttons
  211. delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
  212. yield_increment_button = driver.find_element(By.ID, "yield-increment")
  213. increment_button = driver.find_element(By.ID, "increment")
  214. blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
  215. non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
  216. racy_increment_button = driver.find_element(By.ID, "racy-increment")
  217. driver.find_element(By.ID, "reset")
  218. # get a reference to the counter
  219. counter = driver.find_element(By.ID, "counter")
  220. # get a reference to the iterations input
  221. iterations_input = driver.find_element(By.ID, "iterations")
  222. # kick off background tasks
  223. iterations_input.clear()
  224. iterations_input.send_keys("50")
  225. delayed_increment_button.click()
  226. blocking_pause_button.click()
  227. delayed_increment_button.click()
  228. for _ in range(10):
  229. increment_button.click()
  230. blocking_pause_button.click()
  231. delayed_increment_button.click()
  232. delayed_increment_button.click()
  233. yield_increment_button.click()
  234. racy_increment_button.click()
  235. non_blocking_pause_button.click()
  236. yield_increment_button.click()
  237. blocking_pause_button.click()
  238. yield_increment_button.click()
  239. for _ in range(10):
  240. increment_button.click()
  241. yield_increment_button.click()
  242. blocking_pause_button.click()
  243. assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
  244. # all tasks should have exited and cleaned up
  245. assert background_task._poll_for(
  246. lambda: not background_task.app_instance.background_tasks # type: ignore
  247. )
  248. def test_nested_async_with_self(
  249. background_task: AppHarness,
  250. driver: WebDriver,
  251. token: str,
  252. ):
  253. """Test that nested async with self in the same coroutine raises Exception.
  254. Args:
  255. background_task: harness for BackgroundTask app.
  256. driver: WebDriver instance.
  257. token: The token for the connected client.
  258. """
  259. assert background_task.app_instance is not None
  260. # get a reference to all buttons
  261. nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
  262. increment_button = driver.find_element(By.ID, "increment")
  263. # get a reference to the counter
  264. counter = driver.find_element(By.ID, "counter")
  265. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  266. nested_async_with_self_button.click()
  267. assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
  268. increment_button.click()
  269. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
  270. def test_get_state(
  271. background_task: AppHarness,
  272. driver: WebDriver,
  273. token: str,
  274. ):
  275. """Test that get_state returns a state bound to the correct StateProxy.
  276. Args:
  277. background_task: harness for BackgroundTask app.
  278. driver: WebDriver instance.
  279. token: The token for the connected client.
  280. """
  281. assert background_task.app_instance is not None
  282. # get a reference to all buttons
  283. other_state_button = driver.find_element(By.ID, "increment-from-other-state")
  284. increment_button = driver.find_element(By.ID, "increment")
  285. # get a reference to the counter
  286. counter = driver.find_element(By.ID, "counter")
  287. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  288. other_state_button.click()
  289. assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
  290. increment_button.click()
  291. assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
  292. def test_yield_in_async_with_self(
  293. background_task: AppHarness,
  294. driver: WebDriver,
  295. token: str,
  296. ):
  297. """Test that yielding inside async with self does not disable mutability.
  298. Args:
  299. background_task: harness for BackgroundTask app.
  300. driver: WebDriver instance.
  301. token: The token for the connected client.
  302. """
  303. assert background_task.app_instance is not None
  304. # get a reference to all buttons
  305. yield_in_async_with_self_button = driver.find_element(
  306. By.ID, "yield-in-async-with-self"
  307. )
  308. # get a reference to the counter
  309. counter = driver.find_element(By.ID, "counter")
  310. assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
  311. yield_in_async_with_self_button.click()
  312. assert background_task._poll_for(lambda: counter.text == "2", timeout=5)