123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377 |
- """Test @rx.event(background=True) task functionality."""
- from typing import Generator
- import pytest
- from selenium.webdriver.common.by import By
- from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
    """Define the Reflex app exercised by the background-task tests below.

    This function is handed to ``AppHarness.create`` as ``app_source``. Its
    body (imports, states, the ``index`` page and ``app``) is presumably
    extracted and run as the app module, which would explain why every
    dependency is imported locally — confirm against ``reflex.testing``.
    """
    import asyncio

    import pytest

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        # Rendered in the "counter" heading; every test asserts on its text.
        counter: int = 0
        # Bumped at the start of the delayed/yield-only/racy tasks; it is not
        # read anywhere in this file.
        _task_id: int = 0
        # Loop count for the delayed, yield-only and racy handlers; bound to
        # the "iterations" input via the auto-generated setter.
        iterations: int = 10

        @rx.event(background=True)
        async def handle_event(self):
            """Add 1 to ``counter`` ``iterations`` times from a background task.

            Each increment enters its own ``async with self`` block and the
            sleep happens outside it, presumably so other events can
            interleave between increments.
            """
            async with self:
                self._task_id += 1
            # int() is defensive in case the input setter stored a non-int —
            # NOTE(review): confirm whether this is still needed.
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_event_yield_only(self):
            """Add ``iterations`` to ``counter`` purely by yielding other events.

            Even steps yield the background ``increment_arbitrary(1)`` and odd
            steps yield the regular ``increment``; this task never mutates
            ``counter`` directly.
            """
            async with self:
                self._task_id += 1
            for ix in range(int(self.iterations)):
                if ix % 2 == 0:
                    yield State.increment_arbitrary(1)
                else:
                    yield State.increment()
                await asyncio.sleep(0.005)

        @rx.event
        def increment(self):
            """Regular (non-background) handler adding 1 to ``counter``."""
            self.counter += 1

        @rx.event(background=True)
        async def increment_arbitrary(self, amount: int):
            """Background handler adding ``amount`` to ``counter``."""
            async with self:
                self.counter += int(amount)

        @rx.event
        def reset_counter(self):
            """Set ``counter`` back to 0."""
            self.counter = 0

        @rx.event
        async def blocking_pause(self):
            """Regular handler that only sleeps for 20 ms."""
            await asyncio.sleep(0.02)

        @rx.event(background=True)
        async def non_blocking_pause(self):
            """Background counterpart of ``blocking_pause``: sleeps 20 ms."""
            await asyncio.sleep(0.02)

        async def racy_task(self):
            """Same loop as ``handle_event``; not an event handler itself.

            ``handle_racy_event`` runs four of these concurrently against the
            same state.
            """
            async with self:
                self._task_id += 1
            for _ix in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_racy_event(self):
            """Run four ``racy_task`` coroutines concurrently (4 * iterations)."""
            await asyncio.gather(
                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
            )

        @rx.event(background=True)
        async def nested_async_with_self(self):
            """Add 1, then assert that nesting ``async with self`` is rejected.

            The inner block is expected to raise ImmutableStateError, so its
            increment never runs and the net change is +1.
            """
            async with self:
                self.counter += 1
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            """Triple ``counter`` by delegating to ``ThirdState._triple_count``."""
            third_state = await self.get_state(ThirdState)
            await third_state._triple_count()

        @rx.event(background=True)
        async def yield_in_async_with_self(self):
            """Add 1 twice, yielding between the increments inside the block.

            The second increment checks that the bare ``yield`` does not end
            the block's ability to mutate state.
            """
            async with self:
                self.counter += 1
                yield
                self.counter += 1

    class OtherState(rx.State):
        @rx.event(background=True)
        async def get_other_state(self):
            """Mutate ``State`` through ``get_state`` from another state.

            Net effect on ``State.counter`` starting from 0:
            +1 -> 1, *3 -> 3 (first block); both mutations between the blocks
            are expected to raise ImmutableStateError; then +1 -> 4, *3 -> 12.
            """
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()
            # Outside the `async with`, the fetched state must reject writes.
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1
            # Entering the fetched state directly makes it writable again.
            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        async def _triple_count(self):
            """Multiply ``State.counter`` by 3 (reached via ``State.triple_count``)."""
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        """Render the client token, the counter, the iterations input and one button per handler."""
        return rx.vstack(
            rx.input(
                id="token", value=State.router.session.client_token, is_read_only=True
            ),
            rx.heading(State.counter, id="counter"),
            rx.input(
                id="iterations",
                placeholder="Iterations",
                value=State.iterations.to_string(),  # pyright: ignore [reportAttributeAccessIssue]
                on_change=State.set_iterations,  # pyright: ignore [reportAttributeAccessIssue]
            ),
            rx.button(
                "Delayed Increment",
                on_click=State.handle_event,
                id="delayed-increment",
            ),
            rx.button(
                "Yield Increment",
                on_click=State.handle_event_yield_only,
                id="yield-increment",
            ),
            rx.button("Increment 1", on_click=State.increment, id="increment"),
            rx.button(
                "Blocking Pause",
                on_click=State.blocking_pause,
                id="blocking-pause",
            ),
            rx.button(
                "Non-Blocking Pause",
                on_click=State.non_blocking_pause,
                id="non-blocking-pause",
            ),
            rx.button(
                "Racy Increment (x4)",
                on_click=State.handle_racy_event,
                id="racy-increment",
            ),
            rx.button(
                "Nested Async with Self",
                on_click=State.nested_async_with_self,
                id="nested-async-with-self",
            ),
            rx.button(
                "Increment from OtherState",
                on_click=OtherState.get_other_state,
                id="increment-from-other-state",
            ),
            rx.button(
                "Yield in Async with Self",
                on_click=State.yield_in_async_with_self,
                id="yield-in-async-with-self",
            ),
            rx.button("Reset", on_click=State.reset_counter, id="reset"),
        )

    # The app object with its single page.
    app = rx.App(_state=rx.State)
    app.add_page(index)
@pytest.fixture(scope="module")
def background_task(
    tmp_path_factory,
) -> Generator[AppHarness, None, None]:
    """Run the BackgroundTask app once for the whole test module.

    Args:
        tmp_path_factory: pytest tmp_path_factory fixture

    Yields:
        running AppHarness instance
    """
    app_root = tmp_path_factory.mktemp("background_task")
    harness_cm = AppHarness.create(root=app_root, app_source=BackgroundTask)
    # The context manager owns the app's lifetime; it is torn down after the
    # last test in the module has used it.
    with harness_cm as harness:
        yield harness
@pytest.fixture
def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
    """Open a fresh browser session on the background_task app for one test.

    The session is always shut down when the test finishes, even on failure.

    Args:
        background_task: harness for BackgroundTask app

    Yields:
        WebDriver instance.
    """
    assert background_task.app_instance is not None, "app is not running"
    browser = background_task.frontend()
    try:
        yield browser
    finally:
        browser.quit()
@pytest.fixture()
def token(background_task: AppHarness, driver: WebDriver) -> str:
    """Return the client token of the browser session opened by ``driver``.

    The token is read from the read-only "token" input, which the page binds
    to the session's client token.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.

    Returns:
        The token for the connected client.
    """
    assert background_task.app_instance is not None, "app is not running"
    token_input = driver.find_element(By.ID, "token")
    assert token_input
    # The input only gets a value once the backend connection has sent the
    # token, so poll for it instead of reading it once.
    client_token = background_task.poll_for_value(
        token_input, timeout=DEFAULT_TIMEOUT * 2
    )
    assert client_token is not None, "backend never sent a client token"
    return client_token
def test_background_task(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Test that background tasks work as expected.

    Fires a burst of background handlers (delayed, yield-only, racy and
    non-blocking pause) interleaved with regular handlers (increment and
    blocking pause). Then checks that every increment was applied exactly
    once and that no background task is left running.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    # Bind the narrowed (non-None) app once so the cleanup lambda below does
    # not need a type-checker suppression.
    app_instance = background_task.app_instance
    assert app_instance is not None

    # get a reference to all buttons
    delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
    yield_increment_button = driver.find_element(By.ID, "yield-increment")
    increment_button = driver.find_element(By.ID, "increment")
    blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
    non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
    racy_increment_button = driver.find_element(By.ID, "racy-increment")
    # Not clicked here; the lookup only verifies the button is rendered
    # (find_element raises if the element is missing).
    driver.find_element(By.ID, "reset")

    # get a reference to the counter
    counter = driver.find_element(By.ID, "counter")

    # get a reference to the iterations input
    iterations_input = driver.find_element(By.ID, "iterations")

    # kick off background tasks
    iterations_input.clear()
    iterations_input.send_keys("50")
    delayed_increment_button.click()
    blocking_pause_button.click()
    delayed_increment_button.click()
    for _ in range(10):
        increment_button.click()
    blocking_pause_button.click()
    delayed_increment_button.click()
    delayed_increment_button.click()
    yield_increment_button.click()
    racy_increment_button.click()
    non_blocking_pause_button.click()
    yield_increment_button.click()
    blocking_pause_button.click()
    yield_increment_button.click()
    for _ in range(10):
        increment_button.click()
    yield_increment_button.click()
    blocking_pause_button.click()

    # Expected total with iterations=50:
    #   4 delayed increments  x 50          = 200
    #   4 yield increments    x 50          = 200
    #   1 racy increment (4 tasks x 50)     = 200
    #   20 plain increments                 =  20
    #                                       = 620
    assert background_task._poll_for(lambda: counter.text == "620", timeout=40)

    # all tasks should have exited and cleaned up
    assert background_task._poll_for(lambda: not app_instance._background_tasks)
def test_nested_async_with_self(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Check that nesting ``async with self`` in one coroutine is rejected.

    The handler adds 1 and then expects its inner ``async with self`` to
    raise ImmutableStateError, so the counter moves by exactly one. A regular
    increment afterwards shows the state is still usable.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    nested_button = driver.find_element(By.ID, "nested-async-with-self")
    plus_one_button = driver.find_element(By.ID, "increment")
    counter = driver.find_element(By.ID, "counter")

    def counter_reads(expected: str) -> bool:
        # Wait (up to 5 s) for the rendered counter to show `expected`.
        return background_task._poll_for(
            lambda: counter.text == expected, timeout=5
        )

    assert counter_reads("0")
    nested_button.click()
    assert counter_reads("1")
    plus_one_button.click()
    assert counter_reads("2")
def test_get_state(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Check that ``get_state`` hands back a state bound to the right proxy.

    The OtherState handler takes the counter 0 -> +1 -> x3 -> +1 -> x3 = 12.
    It only writes inside ``async with`` blocks; a later regular increment
    then yields 13.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    other_state_button = driver.find_element(By.ID, "increment-from-other-state")
    plus_one_button = driver.find_element(By.ID, "increment")
    counter = driver.find_element(By.ID, "counter")

    def counter_reads(expected: str) -> bool:
        # Wait (up to 5 s) for the rendered counter to show `expected`.
        return background_task._poll_for(
            lambda: counter.text == expected, timeout=5
        )

    assert counter_reads("0")
    other_state_button.click()
    assert counter_reads("12")
    plus_one_button.click()
    assert counter_reads("13")
def test_yield_in_async_with_self(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Check that a ``yield`` inside ``async with self`` keeps state writable.

    The handler increments before and after a bare ``yield`` within the same
    block, so both increments must land (0 -> 2).

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    yield_button = driver.find_element(By.ID, "yield-in-async-with-self")
    counter = driver.find_element(By.ID, "counter")

    def counter_reads(expected: str) -> bool:
        # Wait (up to 5 s) for the rendered counter to show `expected`.
        return background_task._poll_for(
            lambda: counter.text == expected, timeout=5
        )

    assert counter_reads("0")
    yield_button.click()
    assert counter_reads("2")
|