浏览代码

rx.background and StateManager.modify_state provides safe exclusive access to state (#1676)

Masen Furer 1 年之前
父节点
当前提交
351611ca25

+ 17 - 0
.github/workflows/integration_app_harness.yml

@@ -15,7 +15,23 @@ permissions:
 
 jobs:
   integration-app-harness:
+    strategy:
+      matrix:
+        state_manager: [ "redis", "memory" ]
     runs-on: ubuntu-latest
+    services:
+      # Label used to access the service container
+      redis:
+        image: ${{ matrix.state_manager == 'redis' && 'redis' || '' }}
+        # Set health checks to wait until redis has started
+        options: >-
+          --health-cmd "redis-cli ping"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          # Maps port 6379 on service container to the host
+          - 6379:6379
     steps:
     - uses: actions/checkout@v4
     - uses: ./.github/actions/setup_build_env
@@ -27,6 +43,7 @@ jobs:
     - name: Run app harness tests
       env:
         SCREENSHOT_DIR: /tmp/screenshots
+        REDIS_URL: ${{ matrix.state_manager == 'redis' && 'localhost:6379' || '' }}
       run: |
         poetry run pytest integration
     - uses: actions/upload-artifact@v3

+ 20 - 0
.github/workflows/unit_tests.yml

@@ -40,6 +40,20 @@ jobs:
           - os: windows-latest
             python-version: "3.8.10"
     runs-on: ${{ matrix.os }}
+    # Service containers to run with `runner-job`
+    services:
+      # Label used to access the service container
+      redis:
+        image: ${{ matrix.os == 'ubuntu-latest' && 'redis' || '' }}
+        # Set health checks to wait until redis has started
+        options: >-
+          --health-cmd "redis-cli ping"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          # Maps port 6379 on service container to the host
+          - 6379:6379
     steps:
     - uses: actions/checkout@v4
     - uses: ./.github/actions/setup_build_env
@@ -51,4 +65,10 @@ jobs:
       run: |
         export PYTHONUNBUFFERED=1
         poetry run pytest tests --cov --no-cov-on-fail --cov-report=
+    - name: Run unit tests w/ redis
+      if: ${{ matrix.os == 'ubuntu-latest' }}
+      run: |
+        export PYTHONUNBUFFERED=1
+        export REDIS_URL=localhost:6379
+        poetry run pytest tests --cov --no-cov-on-fail --cov-report=
     - run: poetry run coverage html

+ 6 - 6
.pre-commit-config.yaml

@@ -1,4 +1,10 @@
 repos:
+  - repo: https://github.com/psf/black
+    rev: 22.10.0
+    hooks:
+    - id: black
+      args: [integration, reflex, tests]
+
   - repo: https://github.com/charliermarsh/ruff-pre-commit
     rev: v0.0.244
     hooks:
@@ -17,9 +23,3 @@ repos:
     hooks:
     - id: darglint
       exclude: '^reflex/reflex.py'
-
-  - repo: https://github.com/psf/black
-    rev: 22.10.0
-    hooks:
-    - id: black
-      args: [integration, reflex, tests]

+ 214 - 0
integration/test_background_task.py

@@ -0,0 +1,214 @@
+"""Test @rx.background task functionality."""
+
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
+
+
+def BackgroundTask():
+    """Test that background tasks work as expected."""
+    import asyncio
+
+    import reflex as rx
+
+    class State(rx.State):
+        counter: int = 0
+        _task_id: int = 0
+        iterations: int = 10
+
+        @rx.background
+        async def handle_event(self):
+            async with self:
+                self._task_id += 1
+            for _ix in range(int(self.iterations)):
+                async with self:
+                    self.counter += 1
+                await asyncio.sleep(0.005)
+
+        @rx.background
+        async def handle_event_yield_only(self):
+            async with self:
+                self._task_id += 1
+            for ix in range(int(self.iterations)):
+                if ix % 2 == 0:
+                    yield State.increment_arbitrary(1)  # type: ignore
+                else:
+                    yield State.increment()  # type: ignore
+                await asyncio.sleep(0.005)
+
+        def increment(self):
+            self.counter += 1
+
+        @rx.background
+        async def increment_arbitrary(self, amount: int):
+            async with self:
+                self.counter += int(amount)
+
+        def reset_counter(self):
+            self.counter = 0
+
+        async def blocking_pause(self):
+            await asyncio.sleep(0.02)
+
+        @rx.background
+        async def non_blocking_pause(self):
+            await asyncio.sleep(0.02)
+
+        @rx.cached_var
+        def token(self) -> str:
+            return self.get_token()
+
+    def index() -> rx.Component:
+        return rx.vstack(
+            rx.input(id="token", value=State.token, is_read_only=True),
+            rx.heading(State.counter, id="counter"),
+            rx.input(
+                id="iterations",
+                placeholder="Iterations",
+                value=State.iterations.to_string(),  # type: ignore
+                on_change=State.set_iterations,  # type: ignore
+            ),
+            rx.button(
+                "Delayed Increment",
+                on_click=State.handle_event,
+                id="delayed-increment",
+            ),
+            rx.button(
+                "Yield Increment",
+                on_click=State.handle_event_yield_only,
+                id="yield-increment",
+            ),
+            rx.button("Increment 1", on_click=State.increment, id="increment"),
+            rx.button(
+                "Blocking Pause",
+                on_click=State.blocking_pause,
+                id="blocking-pause",
+            ),
+            rx.button(
+                "Non-Blocking Pause",
+                on_click=State.non_blocking_pause,
+                id="non-blocking-pause",
+            ),
+            rx.button("Reset", on_click=State.reset_counter, id="reset"),
+        )
+
+    app = rx.App(state=State)
+    app.add_page(index)
+    app.compile()
+
+
+@pytest.fixture(scope="session")
+def background_task(
+    tmp_path_factory,
+) -> Generator[AppHarness, None, None]:
+    """Start BackgroundTask app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path_factory.mktemp(f"background_task"),
+        app_source=BackgroundTask,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.fixture
+def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
+    """Get an instance of the browser open to the background_task app.
+
+    Args:
+        background_task: harness for BackgroundTask app
+
+    Yields:
+        WebDriver instance.
+    """
+    assert background_task.app_instance is not None, "app is not running"
+    driver = background_task.frontend()
+    try:
+        yield driver
+    finally:
+        driver.quit()
+
+
+@pytest.fixture()
+def token(background_task: AppHarness, driver: WebDriver) -> str:
+    """Get a function that returns the active token.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+
+    Returns:
+        The token for the connected client
+    """
+    assert background_task.app_instance is not None
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
+    assert token is not None
+
+    return token
+
+
+def test_background_task(
+    background_task: AppHarness,
+    driver: WebDriver,
+    token: str,
+):
+    """Test that background tasks work as expected.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert background_task.app_instance is not None
+
+    # get a reference to all buttons
+    delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
+    yield_increment_button = driver.find_element(By.ID, "yield-increment")
+    increment_button = driver.find_element(By.ID, "increment")
+    blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
+    non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
+    driver.find_element(By.ID, "reset")
+
+    # get a reference to the counter
+    counter = driver.find_element(By.ID, "counter")
+
+    # get a reference to the iterations input
+    iterations_input = driver.find_element(By.ID, "iterations")
+
+    # kick off background tasks
+    iterations_input.clear()
+    iterations_input.send_keys("50")
+    delayed_increment_button.click()
+    blocking_pause_button.click()
+    delayed_increment_button.click()
+    for _ in range(10):
+        increment_button.click()
+    blocking_pause_button.click()
+    delayed_increment_button.click()
+    delayed_increment_button.click()
+    yield_increment_button.click()
+    non_blocking_pause_button.click()
+    yield_increment_button.click()
+    blocking_pause_button.click()
+    yield_increment_button.click()
+    for _ in range(10):
+        increment_button.click()
+    yield_increment_button.click()
+    blocking_pause_button.click()
+    assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
+    # all tasks should have exited and cleaned up
+    assert background_task._poll_for(
+        lambda: not background_task.app_instance.background_tasks  # type: ignore
+    )

+ 23 - 14
integration/test_client_storage.py

@@ -133,7 +133,6 @@ def driver(client_side: AppHarness) -> Generator[WebDriver, None, None]:
     assert client_side.app_instance is not None, "app is not running"
     driver = client_side.frontend()
     try:
-        assert client_side.poll_for_clients()
         yield driver
     finally:
         driver.quit()
@@ -168,7 +167,20 @@ def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
     driver.delete_all_cookies()
 
 
-def test_client_side_state(
+def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]:
+    """Get a map of cookie names to cookie info.
+
+    Args:
+        driver: WebDriver instance.
+
+    Returns:
+        A map of cookie names to cookie info.
+    """
+    return {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
+
+
+@pytest.mark.asyncio
+async def test_client_side_state(
     client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
 ):
     """Test client side state.
@@ -187,8 +199,6 @@ def test_client_side_state(
     token = client_side.poll_for_value(token_input)
     assert token is not None
 
-    backend_state = client_side.app_instance.state_manager.states[token]
-
     # get a reference to the cookie manipulation form
     state_var_input = driver.find_element(By.ID, "state_var")
     input_value_input = driver.find_element(By.ID, "input_value")
@@ -274,7 +284,7 @@ def test_client_side_state(
     input_value_input.send_keys("l1s value")
     set_sub_sub_state_button.click()
 
-    cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
+    cookies = cookie_info_map(driver)
     assert cookies.pop("client_side_state.client_side_sub_state.c1") == {
         "domain": "localhost",
         "httpOnly": False,
@@ -338,8 +348,10 @@ def test_client_side_state(
     state_var_input.send_keys("c3")
     input_value_input.send_keys("c3 value")
     set_sub_state_button.click()
-    cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
-    c3_cookie = cookies["client_side_state.client_side_sub_state.c3"]
+    AppHarness._poll_for(
+        lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
+    )
+    c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
     assert c3_cookie.pop("expiry") is not None
     assert c3_cookie == {
         "domain": "localhost",
@@ -351,9 +363,7 @@ def test_client_side_state(
         "value": "c3%20value",
     }
     time.sleep(2)  # wait for c3 to expire
-    assert "client_side_state.client_side_sub_state.c3" not in {
-        cookie_info["name"] for cookie_info in driver.get_cookies()
-    }
+    assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
 
     local_storage_items = local_storage.items()
     local_storage_items.pop("chakra-ui-color-mode", None)
@@ -426,7 +436,8 @@ def test_client_side_state(
     assert l1s.text == "l1s value"
 
     # reset the backend state to force refresh from client storage
-    backend_state.reset()
+    async with client_side.modify_state(token) as state:
+        state.reset()
     driver.refresh()
 
     # wait for the backend connection to send the token (again)
@@ -465,9 +476,7 @@ def test_client_side_state(
     assert l1s.text == "l1s value"
 
     # make sure c5 cookie shows up on the `/foo` route
-    cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
-
-    assert cookies["client_side_state.client_side_sub_state.c5"] == {
+    assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
         "domain": "localhost",
         "httpOnly": False,
         "name": "client_side_state.client_side_sub_state.c5",

+ 37 - 32
integration/test_dynamic_routes.py

@@ -1,11 +1,10 @@
 """Integration tests for dynamic route page behavior."""
-from typing import Callable, Generator, Type
+from typing import Callable, Coroutine, Generator, Type
 from urllib.parse import urlsplit
 
 import pytest
 from selenium.webdriver.common.by import By
 
-from reflex import State
 from reflex.testing import AppHarness, AppHarnessProd, WebDriver
 
 from .utils import poll_for_navigation
@@ -100,22 +99,21 @@ def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]:
     assert dynamic_route.app_instance is not None, "app is not running"
     driver = dynamic_route.frontend()
     try:
-        assert dynamic_route.poll_for_clients()
         yield driver
     finally:
         driver.quit()
 
 
 @pytest.fixture()
-def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
-    """Get the backend state.
+def token(dynamic_route: AppHarness, driver: WebDriver) -> str:
+    """Get the token associated with backend state.
 
     Args:
         dynamic_route: harness for DynamicRoute app.
         driver: WebDriver instance.
 
     Returns:
-        The backend state associated with the token visible in the driver browser.
+        The token visible in the driver browser.
     """
     assert dynamic_route.app_instance is not None
     token_input = driver.find_element(By.ID, "token")
@@ -125,43 +123,49 @@ def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
     token = dynamic_route.poll_for_value(token_input)
     assert token is not None
 
-    # look up the backend state from the state manager
-    return dynamic_route.app_instance.state_manager.states[token]
+    return token
 
 
 @pytest.fixture()
 def poll_for_order(
-    dynamic_route: AppHarness, backend_state: State
-) -> Callable[[list[str]], None]:
+    dynamic_route: AppHarness, token: str
+) -> Callable[[list[str]], Coroutine[None, None, None]]:
     """Poll for the order list to match the expected order.
 
     Args:
         dynamic_route: harness for DynamicRoute app.
-        backend_state: The backend state associated with the token visible in the driver browser.
+        token: The token visible in the driver browser.
 
     Returns:
-        A function that polls for the order list to match the expected order.
+        An async function that polls for the order list to match the expected order.
     """
 
-    def _poll_for_order(exp_order: list[str]):
-        dynamic_route._poll_for(lambda: backend_state.order == exp_order)
-        assert backend_state.order == exp_order
+    async def _poll_for_order(exp_order: list[str]):
+        async def _backend_state():
+            return await dynamic_route.get_state(token)
+
+        async def _check():
+            return (await _backend_state()).order == exp_order
+
+        await AppHarness._poll_for_async(_check)
+        assert (await _backend_state()).order == exp_order
 
     return _poll_for_order
 
 
-def test_on_load_navigate(
+@pytest.mark.asyncio
+async def test_on_load_navigate(
     dynamic_route: AppHarness,
     driver: WebDriver,
-    backend_state: State,
-    poll_for_order: Callable[[list[str]], None],
+    token: str,
+    poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
 ):
     """Click links to navigate between dynamic pages with on_load event.
 
     Args:
         dynamic_route: harness for DynamicRoute app.
         driver: WebDriver instance.
-        backend_state: The backend state associated with the token visible in the driver browser.
+        token: The token visible in the driver browser.
         poll_for_order: function that polls for the order list to match the expected order.
     """
     assert dynamic_route.app_instance is not None
@@ -184,7 +188,7 @@ def test_on_load_navigate(
         assert page_id_input
 
         assert dynamic_route.poll_for_value(page_id_input) == str(ix)
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
     # manually load the next page to trigger client side routing in prod mode
     if is_prod:
@@ -192,14 +196,14 @@ def test_on_load_navigate(
     exp_order += ["/page/[page_id]-10"]
     with poll_for_navigation(driver):
         driver.get(f"{dynamic_route.frontend_url}/page/10/")
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
     # make sure internal nav still hydrates after redirect
     exp_order += ["/page/[page_id]-11"]
     link = driver.find_element(By.ID, "link_page_next")
     with poll_for_navigation(driver):
         link.click()
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
     # load same page with a query param and make sure it passes through
     if is_prod:
@@ -207,14 +211,14 @@ def test_on_load_navigate(
     exp_order += ["/page/[page_id]-11"]
     with poll_for_navigation(driver):
         driver.get(f"{driver.current_url}?foo=bar")
-    poll_for_order(exp_order)
-    assert backend_state.get_query_params()["foo"] == "bar"
+    await poll_for_order(exp_order)
+    assert (await dynamic_route.get_state(token)).get_query_params()["foo"] == "bar"
 
     # hit a 404 and ensure we still hydrate
     exp_order += ["/404-no page id"]
     with poll_for_navigation(driver):
         driver.get(f"{dynamic_route.frontend_url}/missing")
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
     # browser nav should still trigger hydration
     if is_prod:
@@ -222,14 +226,14 @@ def test_on_load_navigate(
     exp_order += ["/page/[page_id]-11"]
     with poll_for_navigation(driver):
         driver.back()
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
     # next/link to a 404 and ensure we still hydrate
     exp_order += ["/404-no page id"]
     link = driver.find_element(By.ID, "link_missing")
     with poll_for_navigation(driver):
         link.click()
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
     # hit a page that redirects back to dynamic page
     if is_prod:
@@ -237,15 +241,16 @@ def test_on_load_navigate(
     exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"]
     with poll_for_navigation(driver):
         driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
     # should have redirected back to page 0
     assert urlsplit(driver.current_url).path == "/page/0/"
 
 
-def test_on_load_navigate_non_dynamic(
+@pytest.mark.asyncio
+async def test_on_load_navigate_non_dynamic(
     dynamic_route: AppHarness,
     driver: WebDriver,
-    poll_for_order: Callable[[list[str]], None],
+    poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
 ):
     """Click links to navigate between static pages with on_load event.
 
@@ -261,7 +266,7 @@ def test_on_load_navigate_non_dynamic(
     with poll_for_navigation(driver):
         link.click()
     assert urlsplit(driver.current_url).path == "/static/x/"
-    poll_for_order(["/static/x-no page id"])
+    await poll_for_order(["/static/x-no page id"])
 
     # go back to the index and navigate back to the static route
     link = driver.find_element(By.ID, "link_index")
@@ -273,4 +278,4 @@ def test_on_load_navigate_non_dynamic(
     with poll_for_navigation(driver):
         link.click()
     assert urlsplit(driver.current_url).path == "/static/x/"
-    poll_for_order(["/static/x-no page id", "/static/x-no page id"])
+    await poll_for_order(["/static/x-no page id", "/static/x-no page id"])

+ 110 - 22
integration/test_event_chain.py

@@ -1,18 +1,20 @@
 """Ensure that Event Chains are properly queued and handled between frontend and backend."""
 
-import time
 from typing import Generator
 
 import pytest
 from selenium.webdriver.common.by import By
 
-from reflex.testing import AppHarness
+from reflex.testing import AppHarness, WebDriver
 
 MANY_EVENTS = 50
 
 
 def EventChain():
     """App with chained event handlers."""
+    import asyncio
+    import time
+
     import reflex as rx
 
     # repeated here since the outer global isn't exported into the App module
@@ -20,6 +22,7 @@ def EventChain():
 
     class State(rx.State):
         event_order: list[str] = []
+        interim_value: str = ""
 
         @rx.var
         def token(self) -> str:
@@ -111,12 +114,25 @@ def EventChain():
             self.event_order.append("click_return_dict_type")
             return State.event_arg_repr_type({"a": 1})  # type: ignore
 
+        async def click_yield_interim_value_async(self):
+            self.interim_value = "interim"
+            yield
+            await asyncio.sleep(0.5)
+            self.interim_value = "final"
+
+        def click_yield_interim_value(self):
+            self.interim_value = "interim"
+            yield
+            time.sleep(0.5)
+            self.interim_value = "final"
+
     app = rx.App(state=State)
 
     @app.add_page
     def index():
         return rx.fragment(
             rx.input(value=State.token, readonly=True, id="token"),
+            rx.input(value=State.interim_value, readonly=True, id="interim_value"),
             rx.button(
                 "Return Event",
                 id="return_event",
@@ -172,6 +188,16 @@ def EventChain():
                 id="return_dict_type",
                 on_click=State.click_return_dict_type,
             ),
+            rx.button(
+                "Click Yield Interim Value (Async)",
+                id="click_yield_interim_value_async",
+                on_click=State.click_yield_interim_value_async,
+            ),
+            rx.button(
+                "Click Yield Interim Value",
+                id="click_yield_interim_value",
+                on_click=State.click_yield_interim_value,
+            ),
         )
 
     def on_load_return_chain():
@@ -237,7 +263,7 @@ def event_chain(tmp_path_factory) -> Generator[AppHarness, None, None]:
 
 
 @pytest.fixture
-def driver(event_chain: AppHarness):
+def driver(event_chain: AppHarness) -> Generator[WebDriver, None, None]:
     """Get an instance of the browser open to the event_chain app.
 
     Args:
@@ -249,7 +275,6 @@ def driver(event_chain: AppHarness):
     assert event_chain.app_instance is not None, "app is not running"
     driver = event_chain.frontend()
     try:
-        assert event_chain.poll_for_clients()
         yield driver
     finally:
         driver.quit()
@@ -335,7 +360,13 @@ def driver(event_chain: AppHarness):
         ),
     ],
 )
-def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
+@pytest.mark.asyncio
+async def test_event_chain_click(
+    event_chain: AppHarness,
+    driver: WebDriver,
+    button_id: str,
+    exp_event_order: list[str],
+):
     """Click the button, assert that the events are handled in the correct order.
 
     Args:
@@ -350,17 +381,18 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
     assert btn
 
     token = event_chain.poll_for_value(token_input)
+    assert token is not None
 
     btn.click()
-    if "redirect" in button_id:
-        # wait a bit longer if we're redirecting
-        time.sleep(1)
-    if "many_events" in button_id:
-        # wait a bit longer if we have loads of events
-        time.sleep(1)
-    time.sleep(0.5)
-    backend_state = event_chain.app_instance.state_manager.states[token]
-    assert backend_state.event_order == exp_event_order
+
+    async def _has_all_events():
+        return len((await event_chain.get_state(token)).event_order) == len(
+            exp_event_order
+        )
+
+    await AppHarness._poll_for_async(_has_all_events)
+    event_order = (await event_chain.get_state(token)).event_order
+    assert event_order == exp_event_order
 
 
 @pytest.mark.parametrize(
@@ -386,7 +418,13 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
         ),
     ],
 )
-def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
+@pytest.mark.asyncio
+async def test_event_chain_on_load(
+    event_chain: AppHarness,
+    driver: WebDriver,
+    uri: str,
+    exp_event_order: list[str],
+):
     """Load the URI, assert that the events are handled in the correct order.
 
     Args:
@@ -395,16 +433,23 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
         uri: the page to load
         exp_event_order: the expected events recorded in the State
     """
+    assert event_chain.frontend_url is not None
     driver.get(event_chain.frontend_url + uri)
     token_input = driver.find_element(By.ID, "token")
     assert token_input
 
     token = event_chain.poll_for_value(token_input)
+    assert token is not None
 
-    time.sleep(0.5)
-    backend_state = event_chain.app_instance.state_manager.states[token]
-    assert backend_state.is_hydrated is True
+    async def _has_all_events():
+        return len((await event_chain.get_state(token)).event_order) == len(
+            exp_event_order
+        )
+
+    await AppHarness._poll_for_async(_has_all_events)
+    backend_state = await event_chain.get_state(token)
     assert backend_state.event_order == exp_event_order
+    assert backend_state.is_hydrated is True
 
 
 @pytest.mark.parametrize(
@@ -444,7 +489,13 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
         ),
     ],
 )
-def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
+@pytest.mark.asyncio
+async def test_event_chain_on_mount(
+    event_chain: AppHarness,
+    driver: WebDriver,
+    uri: str,
+    exp_event_order: list[str],
+):
     """Load the URI, assert that the events are handled in the correct order.
 
     These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
@@ -458,16 +509,53 @@ def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
         uri: the page to load
         exp_event_order: the expected events recorded in the State
     """
+    assert event_chain.frontend_url is not None
     driver.get(event_chain.frontend_url + uri)
     token_input = driver.find_element(By.ID, "token")
     assert token_input
 
     token = event_chain.poll_for_value(token_input)
+    assert token is not None
 
     unmount_button = driver.find_element(By.ID, "unmount")
     assert unmount_button
     unmount_button.click()
 
-    time.sleep(1)
-    backend_state = event_chain.app_instance.state_manager.states[token]
-    assert backend_state.event_order == exp_event_order
+    async def _has_all_events():
+        return len((await event_chain.get_state(token)).event_order) == len(
+            exp_event_order
+        )
+
+    await AppHarness._poll_for_async(_has_all_events)
+    event_order = (await event_chain.get_state(token)).event_order
+    assert event_order == exp_event_order
+
+
+@pytest.mark.parametrize(
+    ("button_id",),
+    [
+        ("click_yield_interim_value_async",),
+        ("click_yield_interim_value",),
+    ],
+)
+def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_id: str):
+    """Click the button, assert that the interim value is set, then final value is set.
+
+    Args:
+        event_chain: AppHarness for the event_chain app
+        driver: selenium WebDriver open to the app
+        button_id: the ID of the button to click
+    """
+    token_input = driver.find_element(By.ID, "token")
+    interim_value_input = driver.find_element(By.ID, "interim_value")
+    assert event_chain.poll_for_value(token_input)
+
+    btn = driver.find_element(By.ID, button_id)
+    btn.click()
+    assert (
+        event_chain.poll_for_value(interim_value_input, exp_not_equal="") == "interim"
+    )
+    assert (
+        event_chain.poll_for_value(interim_value_input, exp_not_equal="interim")
+        == "final"
+    )

+ 32 - 18
integration/test_form_submit.py

@@ -19,11 +19,16 @@ def FormSubmit():
         def form_submit(self, form_data: dict):
             self.form_data = form_data
 
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
     app = rx.App(state=FormState)
 
     @app.add_page
     def index():
         return rx.vstack(
+            rx.input(value=FormState.token, is_read_only=True, id="token"),
             rx.form(
                 rx.vstack(
                     rx.input(id="name_input"),
@@ -82,13 +87,13 @@ def driver(form_submit: AppHarness):
     """
     driver = form_submit.frontend()
     try:
-        assert form_submit.poll_for_clients()
         yield driver
     finally:
         driver.quit()
 
 
-def test_submit(driver, form_submit: AppHarness):
+@pytest.mark.asyncio
+async def test_submit(driver, form_submit: AppHarness):
     """Fill a form with various different output, submit it to backend and verify
     the output.
 
@@ -97,7 +102,14 @@ def test_submit(driver, form_submit: AppHarness):
         form_submit: harness for FormSubmit app
     """
     assert form_submit.app_instance is not None, "app is not running"
-    _, backend_state = list(form_submit.app_instance.state_manager.states.items())[0]
+
+    # get a reference to the connected client
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = form_submit.poll_for_value(token_input)
+    assert token
 
     name_input = driver.find_element(By.ID, "name_input")
     name_input.send_keys("foo")
@@ -132,19 +144,21 @@ def test_submit(driver, form_submit: AppHarness):
     submit_input = driver.find_element(By.CLASS_NAME, "chakra-button")
     submit_input.click()
 
+    async def get_form_data():
+        return (await form_submit.get_state(token)).form_data
+
     # wait for the form data to arrive at the backend
-    AppHarness._poll_for(
-        lambda: backend_state.form_data != {},
-    )
-
-    assert backend_state.form_data["name_input"] == "foo"
-    assert backend_state.form_data["pin_input"] == pin_values
-    assert backend_state.form_data["number_input"] == "-3"
-    assert backend_state.form_data["bool_input"] is True
-    assert backend_state.form_data["bool_input2"] is True
-    assert backend_state.form_data["slider_input"] == "50"
-    assert backend_state.form_data["range_input"] == ["25", "75"]
-    assert backend_state.form_data["radio_input"] == "option2"
-    assert backend_state.form_data["select_input"] == "option1"
-    assert backend_state.form_data["text_area_input"] == "Some\nText"
-    assert backend_state.form_data["debounce_input"] == "bar baz"
+    form_data = await AppHarness._poll_for_async(get_form_data)
+    assert isinstance(form_data, dict)
+
+    assert form_data["name_input"] == "foo"
+    assert form_data["pin_input"] == pin_values
+    assert form_data["number_input"] == "-3"
+    assert form_data["bool_input"] is True
+    assert form_data["bool_input2"] is True
+    assert form_data["slider_input"] == "50"
+    assert form_data["range_input"] == ["25", "75"]
+    assert form_data["radio_input"] == "option2"
+    assert form_data["select_input"] == "option1"
+    assert form_data["text_area_input"] == "Some\nText"
+    assert form_data["debounce_input"] == "bar baz"

+ 19 - 11
integration/test_input.py

@@ -16,11 +16,16 @@ def FullyControlledInput():
     class State(rx.State):
         text: str = "initial"
 
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
     app = rx.App(state=State)
 
     @app.add_page
     def index():
         return rx.fragment(
+            rx.input(value=State.token, is_read_only=True, id="token"),
             rx.input(
                 id="debounce_input_input",
                 on_change=State.set_text,  # type: ignore
@@ -62,10 +67,12 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     driver = fully_controlled_input.frontend()
 
     # get a reference to the connected client
-    assert len(fully_controlled_input.poll_for_clients()) == 1
-    token, backend_state = list(
-        fully_controlled_input.app_instance.state_manager.states.items()
-    )[0]
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = fully_controlled_input.poll_for_value(token_input)
+    assert token
 
     # find the input and wait for it to have the initial state value
     debounce_input = driver.find_element(By.ID, "debounce_input_input")
@@ -80,14 +87,13 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("foo")
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "ifoonitial"
-    assert backend_state.text == "ifoonitial"
+    assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
     assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
 
     # clear the input on the backend
-    backend_state.text = ""
-    fully_controlled_input.app_instance.state_manager.set_state(token, backend_state)
-    await fully_controlled_input.emit_state_updates()
-    assert backend_state.text == ""
+    async with fully_controlled_input.modify_state(token) as state:
+        state.text = ""
+    assert (await fully_controlled_input.get_state(token)).text == ""
     assert (
         fully_controlled_input.poll_for_value(
             debounce_input, exp_not_equal="ifoonitial"
@@ -99,7 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("getting testing done")
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "getting testing done"
-    assert backend_state.text == "getting testing done"
+    assert (
+        await fully_controlled_input.get_state(token)
+    ).text == "getting testing done"
     assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
 
     # type into the on_change input
@@ -107,7 +115,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "overwrite the state"
     assert on_change_input.get_attribute("value") == "overwrite the state"
-    assert backend_state.text == "overwrite the state"
+    assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
     assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
 
     clear_button.click()

+ 11 - 1
integration/test_server_side_event.py

@@ -33,11 +33,16 @@ def ServerSideEvent():
         def set_value_return_c(self):
             return rx.set_value("c", "")
 
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
     app = rx.App(state=SSState)
 
     @app.add_page
     def index():
         return rx.fragment(
+            rx.input(id="token", value=SSState.token, is_read_only=True),
             rx.input(default_value="a", id="a"),
             rx.input(default_value="b", id="b"),
             rx.input(default_value="c", id="c"),
@@ -106,7 +111,12 @@ def driver(server_side_event: AppHarness):
     assert server_side_event.app_instance is not None, "app is not running"
     driver = server_side_event.frontend()
     try:
-        assert server_side_event.poll_for_clients()
+        token_input = driver.find_element(By.ID, "token")
+        assert token_input
+        # wait for the backend connection to send the token
+        token = server_side_event.poll_for_value(token_input)
+        assert token is not None
+
         yield driver
     finally:
         driver.quit()

+ 16 - 9
integration/test_upload.py

@@ -89,13 +89,13 @@ def driver(upload_file: AppHarness):
     assert upload_file.app_instance is not None, "app is not running"
     driver = upload_file.frontend()
     try:
-        assert upload_file.poll_for_clients()
         yield driver
     finally:
         driver.quit()
 
 
-def test_upload_file(tmp_path, upload_file: AppHarness, driver):
+@pytest.mark.asyncio
+async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
     """Submit a file upload and check that it arrived on the backend.
 
     Args:
@@ -124,16 +124,20 @@ def test_upload_file(tmp_path, upload_file: AppHarness, driver):
     upload_button.click()
 
     # look up the backend state and assert on uploaded contents
-    backend_state = upload_file.app_instance.state_manager.states[token]
-    time.sleep(0.5)
-    assert backend_state._file_data[exp_name] == exp_contents
+    async def get_file_data():
+        return (await upload_file.get_state(token))._file_data
+
+    file_data = await AppHarness._poll_for_async(get_file_data)
+    assert isinstance(file_data, dict)
+    assert file_data[exp_name] == exp_contents
 
     # check that the selected files are displayed
     selected_files = driver.find_element(By.ID, "selected_files")
     assert selected_files.text == exp_name
 
 
-def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
+@pytest.mark.asyncio
+async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
     """Submit several file uploads and check that they arrived on the backend.
 
     Args:
@@ -173,10 +177,13 @@ def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
     upload_button.click()
 
     # look up the backend state and assert on uploaded contents
-    backend_state = upload_file.app_instance.state_manager.states[token]
-    time.sleep(0.5)
+    async def get_file_data():
+        return (await upload_file.get_state(token))._file_data
+
+    file_data = await AppHarness._poll_for_async(get_file_data)
+    assert isinstance(file_data, dict)
     for exp_name, exp_contents in exp_files.items():
-        assert backend_state._file_data[exp_name] == exp_contents
+        assert file_data[exp_name] == exp_contents
 
 
 def test_clear_files(tmp_path, upload_file: AppHarness, driver):

+ 11 - 1
integration/test_var_operations.py

@@ -26,11 +26,16 @@ def VarOperations():
         dict1: dict = {1: 2}
         dict2: dict = {3: 4}
 
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
     app = rx.App(state=VarOperationState)
 
     @app.add_page
     def index():
         return rx.vstack(
+            rx.input(id="token", value=VarOperationState.token, is_read_only=True),
             # INT INT
             rx.text(
                 VarOperationState.int_var1 + VarOperationState.int_var2,
@@ -544,7 +549,12 @@ def driver(var_operations: AppHarness):
     """
     driver = var_operations.frontend()
     try:
-        assert var_operations.poll_for_clients()
+        token_input = driver.find_element(By.ID, "token")
+        assert token_input
+        # wait for the backend connection to send the token
+        token = var_operations.poll_for_value(token_input)
+        assert token is not None
+
         yield driver
     finally:
         driver.quit()

+ 1 - 0
reflex/__init__.py

@@ -21,6 +21,7 @@ from .constants import Env as Env
 from .event import EVENT_ARG as EVENT_ARG
 from .event import EventChain as EventChain
 from .event import FileUpload as upload_files
+from .event import background as background
 from .event import clear_local_storage as clear_local_storage
 from .event import console_log as console_log
 from .event import download as download

+ 164 - 73
reflex/app.py

@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import inspect
 import os
 from multiprocessing.pool import ThreadPool
@@ -13,6 +14,7 @@ from typing import (
     Dict,
     List,
     Optional,
+    Set,
     Type,
     Union,
 )
@@ -49,7 +51,13 @@ from reflex.route import (
     get_route_args,
     verify_route_validity,
 )
-from reflex.state import DefaultState, State, StateManager, StateUpdate
+from reflex.state import (
+    DefaultState,
+    State,
+    StateManager,
+    StateManagerMemory,
+    StateUpdate,
+)
 from reflex.utils import console, format, prerequisites, types
 from reflex.vars import ImportVar
 
@@ -89,7 +97,7 @@ class App(Base):
     state: Type[State] = DefaultState
 
     # Class to manage many client states.
-    state_manager: StateManager = StateManager()
+    state_manager: StateManager = StateManagerMemory(state=DefaultState)
 
     # The styling to apply to each component.
     style: ComponentStyle = {}
@@ -104,13 +112,16 @@ class App(Base):
     admin_dash: Optional[AdminDash] = None
 
     # The async server name space
-    event_namespace: Optional[AsyncNamespace] = None
+    event_namespace: Optional[EventNamespace] = None
 
     # A component that is present on every page.
     overlay_component: Optional[
         Union[Component, ComponentCallable]
     ] = default_overlay_component
 
+    # Background tasks that are currently running
+    background_tasks: Set[asyncio.Task] = set()
+
     def __init__(self, *args, **kwargs):
         """Initialize the app.
 
@@ -154,7 +165,7 @@ class App(Base):
         self.middleware.append(HydrateMiddleware())
 
         # Set up the state manager.
-        self.state_manager.setup(state=self.state)
+        self.state_manager = StateManager.create(state=self.state)
 
         # Set up the API.
         self.api = FastAPI()
@@ -646,6 +657,76 @@ class App(Base):
         thread_pool.close()
         thread_pool.join()
 
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state out of band.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state to modify.
+
+        Raises:
+            RuntimeError: If the app has not been initialized yet.
+        """
+        if self.event_namespace is None:
+            raise RuntimeError("App has not been initialized yet.")
+        # Get exclusive access to the state.
+        async with self.state_manager.modify_state(token) as state:
+            # No other event handler can modify the state while in this context.
+            yield state
+            delta = state.get_delta()
+            if delta:
+                # When the state is modified reset dirty status and emit the delta to the frontend.
+                state._clean()
+                await self.event_namespace.emit_update(
+                    update=StateUpdate(delta=delta),
+                    sid=state.get_sid(),
+                )
+
+    def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
+        """Process an event in the background and emit updates as they arrive.
+
+        Args:
+            state: The state to process the event for.
+            event: The event to process.
+
+        Returns:
+            Task if the event was backgroundable, otherwise None
+        """
+        substate, handler = state._get_event_handler(event)
+        if not handler.is_background:
+            return None
+
+        async def _coro():
+            """Coroutine to process the event and emit updates inside an asyncio.Task.
+
+            Raises:
+                RuntimeError: If the app has not been initialized yet.
+            """
+            if self.event_namespace is None:
+                raise RuntimeError("App has not been initialized yet.")
+
+            # Process the event.
+            async for update in state._process_event(
+                handler=handler, state=substate, payload=event.payload
+            ):
+                # Postprocess the event.
+                update = await self.postprocess(state, event, update)
+
+                # Send the update to the client.
+                await self.event_namespace.emit_update(
+                    update=update,
+                    sid=state.get_sid(),
+                )
+
+        task = asyncio.create_task(_coro())
+        self.background_tasks.add(task)
+        # Clean up task from background_tasks set when complete.
+        task.add_done_callback(self.background_tasks.discard)
+        return task
+
 
 async def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
@@ -662,9 +743,6 @@ async def process(
     Yields:
         The state updates after processing the event.
     """
-    # Get the state for the session.
-    state = app.state_manager.get_state(event.token)
-
     # Add request data to the state.
     router_data = event.router_data
     router_data.update(
@@ -676,31 +754,35 @@ async def process(
             constants.RouteVar.CLIENT_IP: client_ip,
         }
     )
-    # re-assign only when the value is different
-    if state.router_data != router_data:
-        # assignment will recurse into substates and force recalculation of
-        # dependent ComputedVar (dynamic route variables)
-        state.router_data = router_data
-
-    # Preprocess the event.
-    update = await app.preprocess(state, event)
-
-    # If there was an update, yield it.
-    if update is not None:
-        yield update
-
-    # Only process the event if there is no update.
-    else:
-        # Process the event.
-        async for update in state._process(event):
-            # Postprocess the event.
-            update = await app.postprocess(state, event, update)
-
-            # Yield the update.
+    # Get the state for the session exclusively.
+    async with app.state_manager.modify_state(event.token) as state:
+        # re-assign only when the value is different
+        if state.router_data != router_data:
+            # assignment will recurse into substates and force recalculation of
+            # dependent ComputedVar (dynamic route variables)
+            state.router_data = router_data
+
+        # Preprocess the event.
+        update = await app.preprocess(state, event)
+
+        # If there was an update, yield it.
+        if update is not None:
             yield update
 
-    # Set the state for the session.
-    app.state_manager.set_state(event.token, state)
+        # Only process the event if there is no update.
+        else:
+            if app._process_background(state, event) is not None:
+                # `final=True` allows the frontend send more events immediately.
+                yield StateUpdate(final=True)
+                return
+
+            # Process the event synchronously.
+            async for update in state._process(event):
+                # Postprocess the event.
+                update = await app.postprocess(state, event, update)
+
+                # Yield the update.
+                yield update
 
 
 async def ping() -> str:
@@ -737,47 +819,46 @@ def upload(app: App):
             assert file.filename is not None
             file.filename = file.filename.split(":")[-1]
         # Get the state for the session.
-        state = app.state_manager.get_state(token)
-        # get the current session ID
-        sid = state.get_sid()
-        # get the current state(parent state/substate)
-        path = handler.split(".")[:-1]
-        current_state = state.get_substate(path)
-        handler_upload_param = ()
-
-        # get handler function
-        func = getattr(current_state, handler.split(".")[-1])
-
-        # check if there exists any handler args with annotation, List[UploadFile]
-        for k, v in inspect.getfullargspec(
-            func.fn if isinstance(func, EventHandler) else func
-        ).annotations.items():
-            if types.is_generic_alias(v) and types._issubclass(
-                v.__args__[0], UploadFile
-            ):
-                handler_upload_param = (k, v)
-                break
+        async with app.state_manager.modify_state(token) as state:
+            # get the current session ID
+            sid = state.get_sid()
+            # get the current state(parent state/substate)
+            path = handler.split(".")[:-1]
+            current_state = state.get_substate(path)
+            handler_upload_param = ()
+
+            # get handler function
+            func = getattr(current_state, handler.split(".")[-1])
+
+            # check if there exists any handler args with annotation, List[UploadFile]
+            for k, v in inspect.getfullargspec(
+                func.fn if isinstance(func, EventHandler) else func
+            ).annotations.items():
+                if types.is_generic_alias(v) and types._issubclass(
+                    v.__args__[0], UploadFile
+                ):
+                    handler_upload_param = (k, v)
+                    break
+
+            if not handler_upload_param:
+                raise ValueError(
+                    f"`{handler}` handler should have a parameter annotated as List["
+                    f"rx.UploadFile]"
+                )
 
-        if not handler_upload_param:
-            raise ValueError(
-                f"`{handler}` handler should have a parameter annotated as List["
-                f"rx.UploadFile]"
+            event = Event(
+                token=token,
+                name=handler,
+                payload={handler_upload_param[0]: files},
             )
-
-        event = Event(
-            token=token,
-            name=handler,
-            payload={handler_upload_param[0]: files},
-        )
-        async for update in state._process(event):
-            # Postprocess the event.
-            update = await app.postprocess(state, event, update)
-            # Send update to client
-            await asyncio.create_task(
-                app.event_namespace.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)  # type: ignore
-            )
-        # Set the state for the session.
-        app.state_manager.set_state(event.token, state)
+            async for update in state._process(event):
+                # Postprocess the event.
+                update = await app.postprocess(state, event, update)
+                # Send update to client
+                await app.event_namespace.emit_update(  # type: ignore
+                    update=update,
+                    sid=sid,
+                )
 
     return upload_file
 
@@ -815,6 +896,18 @@ class EventNamespace(AsyncNamespace):
         """
         pass
 
+    async def emit_update(self, update: StateUpdate, sid: str) -> None:
+        """Emit an update to the client.
+
+        Args:
+            update: The state update to send.
+            sid: The Socket.IO session id.
+        """
+        # Creating a task prevents the update from being blocked behind other coroutines.
+        await asyncio.create_task(
+            self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
+        )
+
     async def on_event(self, sid, data):
         """Event for receiving front-end websocket events.
 
@@ -841,10 +934,8 @@ class EventNamespace(AsyncNamespace):
 
         # Process the events.
         async for update in process(self.app, event, sid, headers, client_ip):
-            # Emit the event.
-            await asyncio.create_task(
-                self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
-            )
+            # Emit the update from processing the event.
+            await self.emit_update(update=update, sid=sid)
 
     async def on_ping(self, sid):
         """Event for testing the API endpoint.

+ 8 - 0
reflex/app.pyi

@@ -1,5 +1,6 @@
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
 
+import asyncio
 from fastapi import FastAPI
 from fastapi import UploadFile as UploadFile
 from reflex import constants as constants
@@ -45,12 +46,14 @@ from reflex.utils import (
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from typing import (
     Any,
+    AsyncContextManager,
     AsyncIterator,
     Callable,
     Coroutine,
     Dict,
     List,
     Optional,
+    Set,
     Type,
     Union,
     overload,
@@ -75,6 +78,7 @@ class App(Base):
     admin_dash: Optional[AdminDash]
     event_namespace: Optional[AsyncNamespace]
     overlay_component: Optional[Union[Component, ComponentCallable]]
+    background_tasks: Set[asyncio.Task] = set()
     def __init__(
         self,
         *args,
@@ -116,6 +120,10 @@ class App(Base):
     def setup_admin_dash(self) -> None: ...
     def get_frontend_packages(self, imports: Dict[str, str]): ...
     def compile(self) -> None: ...
+    def modify_state(self, token: str) -> AsyncContextManager[State]: ...
+    def _process_background(
+        self, state: State, event: Event
+    ) -> asyncio.Task | None: ...
 
 async def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str

+ 2 - 0
reflex/constants.py

@@ -219,6 +219,8 @@ OLD_CONFIG_FILE = f"pcconfig{PY_EXT}"
 PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app"
 # Token expiration time in seconds.
 TOKEN_EXPIRATION = 60 * 60
+# Maximum time in milliseconds that a state can be locked for exclusive access.
+LOCK_EXPIRATION = 10000
 
 # Testing variables.
 # Testing os env set by pytest when running a test case.

+ 84 - 2
reflex/event.py

@@ -2,7 +2,17 @@
 from __future__ import annotations
 
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 from reflex import constants
 from reflex.base import Base
@@ -10,6 +20,9 @@ from reflex.utils import console, format
 from reflex.utils.types import ArgsSpec
 from reflex.vars import BaseVar, Var
 
+if TYPE_CHECKING:
+    from reflex.state import State
+
 
 class Event(Base):
     """An event that describes any state change in the app."""
@@ -27,6 +40,66 @@ class Event(Base):
     payload: Dict[str, Any] = {}
 
 
+BACKGROUND_TASK_MARKER = "_reflex_background_task"
+
+
+def background(fn):
+    """Decorator to mark event handler as running in the background.
+
+    Args:
+        fn: The function to decorate.
+
+    Returns:
+        The same function, but with a marker set.
+
+
+    Raises:
+        TypeError: If the function is not a coroutine function or async generator.
+    """
+    if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn):
+        raise TypeError("Background task must be async function or generator.")
+    setattr(fn, BACKGROUND_TASK_MARKER, True)
+    return fn
+
+
+def _no_chain_background_task(
+    state_cls: Type["State"], name: str, fn: Callable
+) -> Callable:
+    """Protect against directly chaining a background task from another event handler.
+
+    Args:
+        state_cls: The state class that the event handler is in.
+        name: The name of the background task.
+        fn: The background task coroutine function / generator.
+
+    Returns:
+        A compatible coroutine function / generator that raises a runtime error.
+
+    Raises:
+        TypeError: If the background task is not async.
+    """
+    call = f"{state_cls.__name__}.{name}"
+    message = (
+        f"Cannot directly call background task {name!r}, use "
+        f"`yield {call}` or `return {call}` instead."
+    )
+    if inspect.iscoroutinefunction(fn):
+
+        async def _no_chain_background_task_co(*args, **kwargs):
+            raise RuntimeError(message)
+
+        return _no_chain_background_task_co
+    if inspect.isasyncgenfunction(fn):
+
+        async def _no_chain_background_task_gen(*args, **kwargs):
+            yield
+            raise RuntimeError(message)
+
+        return _no_chain_background_task_gen
+
+    raise TypeError(f"{fn} is marked as a background task, but is not async.")
+
+
 class EventHandler(Base):
     """An event handler responds to an event to update the state."""
 
@@ -39,6 +112,15 @@ class EventHandler(Base):
         # Needed to allow serialization of Callable.
         frozen = True
 
+    @property
+    def is_background(self) -> bool:
+        """Whether the event handler is a background task.
+
+        Returns:
+            True if the event handler is marked as a background task.
+        """
+        return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
+
     def __call__(self, *args: Var) -> EventSpec:
         """Pass arguments to the handler to get an event spec.
 
@@ -530,7 +612,7 @@ def get_handler_args(event_spec: EventSpec) -> tuple[tuple[Var, Var], ...]:
 
 
 def fix_events(
-    events: list[EventHandler | EventSpec],
+    events: list[EventHandler | EventSpec] | None,
     token: str,
     router_data: dict[str, Any] | None = None,
 ) -> list[Event]:

+ 546 - 69
reflex/state.py

@@ -2,13 +2,15 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import copy
 import functools
 import inspect
 import json
 import traceback
 import urllib.parse
-from abc import ABC
+import uuid
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from types import FunctionType
 from typing import (
@@ -27,12 +29,20 @@ from typing import (
 import cloudpickle
 import pydantic
 import wrapt
-from redis import Redis
+from redis.asyncio import Redis
 
 from reflex import constants
 from reflex.base import Base
-from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
+from reflex.event import (
+    Event,
+    EventHandler,
+    EventSpec,
+    _no_chain_background_task,
+    fix_events,
+    window_alert,
+)
 from reflex.utils import format, prerequisites, types
+from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.vars import BaseVar, ComputedVar, Var
 
 Delta = Dict[str, Any]
@@ -152,7 +162,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
         # Convert the event handlers to functions.
         for name, event_handler in state.event_handlers.items():
-            fn = functools.partial(event_handler.fn, self)
+            if event_handler.is_background:
+                fn = _no_chain_background_task(type(state), name, event_handler.fn)
+            else:
+                fn = functools.partial(event_handler.fn, self)
             fn.__module__ = event_handler.fn.__module__  # type: ignore
             fn.__qualname__ = event_handler.fn.__qualname__  # type: ignore
             setattr(self, name, fn)
@@ -711,52 +724,56 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             raise ValueError(f"Invalid path: {path}")
         return self.substates[path[0]].get_substate(path[1:])
 
-    async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
-        """Obtain event info and process event.
+    def _get_event_handler(
+        self, event: Event
+    ) -> tuple[State | StateProxy, EventHandler]:
+        """Get the event handler for the given event.
 
         Args:
-            event: The event to process.
+            event: The event to get the handler for.
 
-        Yields:
-            The state update after processing the event.
+
+        Returns:
+            The event handler.
 
         Raises:
-            ValueError: If the state value is None.
+            ValueError: If the event handler or substate is not found.
         """
         # Get the event handler.
         path = event.name.split(".")
         path, name = path[:-1], path[-1]
         substate = self.get_substate(path)
-        handler = substate.event_handlers[name]  # type: ignore
-
         if not substate:
             raise ValueError(
                 "The value of state cannot be None when processing an event."
             )
+        handler = substate.event_handlers[name]
 
-        # Get the event generator.
-        event_iter = self._process_event(
-            handler=handler,
-            state=substate,
-            payload=event.payload,
-        )
+        # For background tasks, proxy the state
+        if handler.is_background:
+            substate = StateProxy(substate)
 
-        # Clean the state before processing the event.
-        self._clean()
+        return substate, handler
 
-        # Run the event generator and return state updates.
-        async for events, final in event_iter:
-            # Fix the returned events.
-            events = fix_events(events, event.token)  # type: ignore
+    async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
+        """Obtain event info and process event.
 
-            # Get the delta after processing the event.
-            delta = self.get_delta()
+        Args:
+            event: The event to process.
 
-            # Yield the state update.
-            yield StateUpdate(delta=delta, events=events, final=final)
+        Yields:
+            The state update after processing the event.
+        """
+        # Get the event handler.
+        substate, handler = self._get_event_handler(event)
 
-            # Clean the state to prepare for the next event.
-            self._clean()
+        # Run the event generator and yield state updates.
+        async for update in self._process_event(
+            handler=handler,
+            state=substate,
+            payload=event.payload,
+        ):
+            yield update
 
     def _check_valid(self, handler: EventHandler, events: Any) -> Any:
         """Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
@@ -787,9 +804,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
         )
 
+    def _as_state_update(
+        self,
+        handler: EventHandler,
+        events: EventSpec | list[EventSpec] | None,
+        final: bool,
+    ) -> StateUpdate:
+        """Convert the events to a StateUpdate.
+
+        Fixes the events and checks for validity before converting.
+
+        Args:
+            handler: The handler where the events originated from.
+            events: The events to queue with the update.
+            final: Whether the handler is done processing.
+
+        Returns:
+            The valid StateUpdate containing the events and final flag.
+        """
+        token = self.get_token()
+
+        # Convert valid EventHandler and EventSpec into Event
+        fixed_events = fix_events(self._check_valid(handler, events), token)
+
+        # Get the delta after processing the event.
+        delta = self.get_delta()
+        self._clean()
+
+        return StateUpdate(
+            delta=delta,
+            events=fixed_events,
+            final=final if not handler.is_background else True,
+        )
+
     async def _process_event(
-        self, handler: EventHandler, state: State, payload: Dict
-    ) -> AsyncIterator[tuple[list[EventSpec] | None, bool]]:
+        self, handler: EventHandler, state: State | StateProxy, payload: Dict
+    ) -> AsyncIterator[StateUpdate]:
         """Process event.
 
         Args:
@@ -798,13 +848,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             payload: The event payload.
 
         Yields:
-            Tuple containing:
-                0: The state update after processing the event.
-                1: Whether the event is the final event.
+            StateUpdate object
         """
         # Get the function to process the event.
         fn = functools.partial(handler.fn, state)
 
+        # Clean the state before processing the event.
+        self._clean()
+
         # Wrap the function in a try/except block.
         try:
             # Handle async functions.
@@ -817,30 +868,34 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             # Handle async generators.
             if inspect.isasyncgen(events):
                 async for event in events:
-                    yield self._check_valid(handler, event), False
-                yield None, True
+                    yield self._as_state_update(handler, event, final=False)
+                yield self._as_state_update(handler, events=None, final=True)
 
             # Handle regular generators.
             elif inspect.isgenerator(events):
                 try:
                     while True:
-                        yield self._check_valid(handler, next(events)), False
+                        yield self._as_state_update(handler, next(events), final=False)
                 except StopIteration as si:
                     # the "return" value of the generator is not available
                     # in the loop, we must catch StopIteration to access it
                     if si.value is not None:
-                        yield self._check_valid(handler, si.value), False
-                yield None, True
+                        yield self._as_state_update(handler, si.value, final=False)
+                yield self._as_state_update(handler, events=None, final=True)
 
             # Handle regular event chains.
             else:
-                yield self._check_valid(handler, events), True
+                yield self._as_state_update(handler, events, final=True)
 
         # If an error occurs, throw a window alert.
         except Exception:
             error = traceback.format_exc()
             print(error)
-            yield [window_alert("An error occurred. See logs for details.")], True
+            yield self._as_state_update(
+                handler,
+                window_alert("An error occurred. See logs for details."),
+                final=True,
+            )
 
     def _always_dirty_computed_vars(self) -> set[str]:
         """The set of ComputedVars that always need to be recalculated.
@@ -989,6 +1044,160 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         variables = {**base_vars, **computed_vars, **substate_vars}
         return {k: variables[k] for k in sorted(variables)}
 
+    async def __aenter__(self) -> State:
+        """Enter the async context manager protocol.
+
+        This should not be used for the State class, but exists for
+        type-compatibility with StateProxy.
+
+        Raises:
+            TypeError: always, because async contextmanager protocol is only supported for background task.
+        """
+        raise TypeError(
+            "Only background task should use `async with self` to modify state."
+        )
+
+    async def __aexit__(self, *exc_info: Any) -> None:
+        """Exit the async context manager protocol.
+
+        This should not be used for the State class, but exists for
+        type-compatibility with StateProxy.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        pass
+
+
+class StateProxy(wrapt.ObjectProxy):
+    """Proxy of a state instance to control mutability of vars for a background task.
+
+    Since a background task runs against a state instance without holding the
+    state_manager lock for the token, the reference may become stale if the same
+    state is modified by another event handler.
+
+    The proxy object ensures that writes to the state are blocked unless
+    explicitly entering a context which refreshes the state from state_manager
+    and holds the lock for the token until exiting the context. After exiting
+    the context, a StateUpdate may be emitted to the frontend to notify the
+    client of the state change.
+
+    A background task will be passed the `StateProxy` as `self`, so mutability
+    can be safely performed inside an `async with self` block.
+
+        class State(rx.State):
+            counter: int = 0
+
+            @rx.background
+            async def bg_increment(self):
+                await asyncio.sleep(1)
+                async with self:
+                    self.counter += 1
+    """
+
+    def __init__(self, state_instance):
+        """Create a proxy for a state instance.
+
+        Args:
+            state_instance: The state instance to proxy.
+        """
+        super().__init__(state_instance)
+        self._self_app = getattr(prerequisites.get_app(), constants.APP_VAR)
+        self._self_substate_path = state_instance.get_full_name().split(".")
+        self._self_actx = None
+        self._self_mutable = False
+
+    async def __aenter__(self) -> StateProxy:
+        """Enter the async context manager protocol.
+
+        Sets mutability to True and enters the `App.modify_state` async context,
+        which refreshes the state from state_manager and holds the lock for the
+        given state token until exiting the context.
+
+        Background tasks should avoid blocking calls while inside the context.
+
+        Returns:
+            This StateProxy instance in mutable mode.
+        """
+        self._self_actx = self._self_app.modify_state(self.__wrapped__.get_token())
+        mutable_state = await self._self_actx.__aenter__()
+        super().__setattr__(
+            "__wrapped__", mutable_state.get_substate(self._self_substate_path)
+        )
+        self._self_mutable = True
+        return self
+
+    async def __aexit__(self, *exc_info: Any) -> None:
+        """Exit the async context manager protocol.
+
+        Sets proxy mutability to False and persists any state changes.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        if self._self_actx is None:
+            return
+        self._self_mutable = False
+        await self._self_actx.__aexit__(*exc_info)
+        self._self_actx = None
+
+    def __enter__(self):
+        """Enter the regular context manager protocol.
+
+        This is not supported for background tasks, and exists only to raise a more useful exception
+        when the StateProxy is used incorrectly.
+
+        Raises:
+            TypeError: always, because only async contextmanager protocol is supported.
+        """
+        raise TypeError("Background task must use `async with self` to modify state.")
+
+    def __exit__(self, *exc_info: Any) -> None:
+        """Exit the regular context manager protocol.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        pass
+
+    def __getattr__(self, name: str) -> Any:
+        """Get the attribute from the underlying state instance.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The value of the attribute.
+        """
+        value = super().__getattr__(name)
+        if not name.startswith("_self_") and isinstance(value, MutableProxy):
+            # ensure mutations to these containers are blocked unless proxy is _mutable
+            return ImmutableMutableProxy(
+                wrapped=value.__wrapped__,
+                state=self,  # type: ignore
+                field_name=value._self_field_name,
+            )
+        return value
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set the attribute on the underlying state instance.
+
+        If the attribute is internal, set it on the proxy instance instead.
+
+        Args:
+            name: The name of the attribute.
+            value: The value of the attribute.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if not name.startswith("_self_") and not self._self_mutable:
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        super().__setattr__(name, value)
+
 
 class DefaultState(State):
     """The default empty state."""
@@ -1009,31 +1218,83 @@ class StateUpdate(Base):
     final: bool = True
 
 
-class StateManager(Base):
+class StateManager(Base, ABC):
     """A class to manage many client states."""
 
     # The state class to use.
-    state: Type[State] = DefaultState
+    state: Type[State]
 
-    # The mapping of client ids to states.
-    states: Dict[str, State] = {}
+    @classmethod
+    def create(cls, state: Type[State] = DefaultState):
+        """Create a new state manager.
 
-    # The token expiration time (s).
-    token_expiration: int = constants.TOKEN_EXPIRATION
+        Args:
+            state: The state class to use.
 
-    # The redis client to use.
-    redis: Optional[Redis] = None
+        Returns:
+            The state manager (either memory or redis).
+        """
+        redis = prerequisites.get_redis()
+        if redis is not None:
+            return StateManagerRedis(state=state, redis=redis)
+        return StateManagerMemory(state=state)
 
-    def setup(self, state: Type[State]):
-        """Set up the state manager.
+    @abstractmethod
+    async def get_state(self, token: str) -> State:
+        """Get the state for a token.
 
         Args:
-            state: The state class to use.
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        pass
+
+    @abstractmethod
+    async def set_state(self, token: str, state: State):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        pass
+
+    @abstractmethod
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
         """
-        self.state = state
-        self.redis = prerequisites.get_redis()
+        yield self.state()
 
-    def get_state(self, token: str) -> State:
+
+class StateManagerMemory(StateManager):
+    """A state manager that stores states in memory."""
+
+    # The mapping of client ids to states.
+    states: Dict[str, State] = {}
+
+    # The mutex ensures the dict of mutexes is updated exclusively
+    _state_manager_lock = asyncio.Lock()
+
+    # The dict of mutexes for each client
+    _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
+
+    class Config:
+        """The Pydantic config."""
+
+        fields = {
+            "_states_locks": {"exclude": True},
+        }
+
+    async def get_state(self, token: str) -> State:
         """Get the state for a token.
 
         Args:
@@ -1042,27 +1303,212 @@ class StateManager(Base):
         Returns:
             The state for the token.
         """
-        if self.redis is not None:
-            redis_state = self.redis.get(token)
-            if redis_state is None:
-                self.set_state(token, self.state())
-                return self.get_state(token)
-            return cloudpickle.loads(redis_state)
-
         if token not in self.states:
             self.states[token] = self.state()
         return self.states[token]
 
-    def set_state(self, token: str, state: State):
+    async def set_state(self, token: str, state: State):
         """Set the state for a token.
 
         Args:
             token: The token to set the state for.
             state: The state to set.
         """
-        if self.redis is None:
-            return
-        self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
+        pass
+
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        if token not in self._states_locks:
+            async with self._state_manager_lock:
+                if token not in self._states_locks:
+                    self._states_locks[token] = asyncio.Lock()
+
+        async with self._states_locks[token]:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state)
+
+
+class StateManagerRedis(StateManager):
+    """A state manager that stores states in redis."""
+
+    # The redis client to use.
+    redis: Redis
+
+    # The token expiration time (s).
+    token_expiration: int = constants.TOKEN_EXPIRATION
+
+    # The maximum time to hold a lock (ms).
+    lock_expiration: int = constants.LOCK_EXPIRATION
+
+    # The keyspace subscription string when redis is waiting for lock to be released
+    _redis_notify_keyspace_events: str = (
+        "K"  # Enable keyspace notifications (target a particular key)
+        "g"  # For generic commands (DEL, EXPIRE, etc)
+        "x"  # For expired events
+        "e"  # For evicted events (i.e. maxmemory exceeded)
+    )
+
+    # These events indicate that a lock is no longer held
+    _redis_keyspace_lock_release_events: Set[bytes] = {
+        b"del",
+        b"expire",
+        b"expired",
+        b"evicted",
+    }
+
+    async def get_state(self, token: str) -> State:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        redis_state = await self.redis.get(token)
+        if redis_state is None:
+            await self.set_state(token, self.state())
+            return await self.get_state(token)
+        return cloudpickle.loads(redis_state)
+
+    async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+            lock_id: If provided, the lock_key must be set to this value to set the state.
+
+        Raises:
+            LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
+        """
+        # check that we're holding the lock
+        if (
+            lock_id is not None
+            and await self.redis.get(self._lock_key(token)) != lock_id
+        ):
+            raise LockExpiredError(
+                f"Lock expired for token {token} while processing. Consider increasing "
+                f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
+                "or use `@rx.background` decorator for long-running tasks."
+            )
+        await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
+
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        async with self._lock(token) as lock_id:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state, lock_id)
+
+    @staticmethod
+    def _lock_key(token: str) -> bytes:
+        """Get the redis key for a token's lock.
+
+        Args:
+            token: The token to get the lock key for.
+
+        Returns:
+            The redis lock key for the token.
+        """
+        return f"{token}_lock".encode()
+
+    async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
+        """Try to get a redis lock for a token.
+
+        Args:
+            lock_key: The redis key for the lock.
+            lock_id: The ID of the lock.
+
+        Returns:
+            True if the lock was obtained.
+        """
+        return await self.redis.set(
+            lock_key,
+            lock_id,
+            px=self.lock_expiration,
+            nx=True,  # only set if it doesn't exist
+        )
+
+    async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
+        """Wait for a redis lock to be released via pubsub.
+
+        Coroutine will not return until the lock is obtained.
+
+        Args:
+            lock_key: The redis key for the lock.
+            lock_id: The ID of the lock.
+        """
+        state_is_locked = False
+        lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
+        # Enable keyspace notifications for the lock key, so we know when it is available.
+        await self.redis.config_set(
+            "notify-keyspace-events", self._redis_notify_keyspace_events
+        )
+        async with self.redis.pubsub() as pubsub:
+            await pubsub.psubscribe(lock_key_channel)
+            while not state_is_locked:
+                # wait for the lock to be released
+                while True:
+                    if not await self.redis.exists(lock_key):
+                        break  # key was removed, try to get the lock again
+                    message = await pubsub.get_message(
+                        ignore_subscribe_messages=True,
+                        timeout=self.lock_expiration / 1000.0,
+                    )
+                    if message is None:
+                        continue
+                    if message["data"] in self._redis_keyspace_lock_release_events:
+                        break
+                state_is_locked = await self._try_get_lock(lock_key, lock_id)
+
+    @contextlib.asynccontextmanager
+    async def _lock(self, token: str):
+        """Obtain a redis lock for a token.
+
+        Args:
+            token: The token to obtain a lock for.
+
+        Yields:
+            The ID of the lock (to be passed to set_state).
+
+        Raises:
+            LockExpiredError: If the lock has expired while processing the event.
+        """
+        lock_key = self._lock_key(token)
+        lock_id = uuid.uuid4().hex.encode()
+
+        if not await self._try_get_lock(lock_key, lock_id):
+            # Missed the fast-path to get lock, subscribe for lock delete/expire events
+            await self._wait_lock(lock_key, lock_id)
+        state_is_locked = True
+
+        try:
+            yield lock_id
+        except LockExpiredError:
+            state_is_locked = False
+            raise
+        finally:
+            if state_is_locked:
+                # only delete our lock
+                await self.redis.delete(lock_key)
 
 
 class ClientStorageBase:
@@ -1246,7 +1692,7 @@ class MutableProxy(wrapt.ObjectProxy):
             value, super().__getattribute__("__mutable_types__")
         ) and __name not in ("__wrapped__", "_self_state"):
             # Recursively wrap mutable attribute values retrieved through this proxy.
-            return MutableProxy(
+            return type(self)(
                 wrapped=value,
                 state=self._self_state,
                 field_name=self._self_field_name,
@@ -1266,7 +1712,7 @@ class MutableProxy(wrapt.ObjectProxy):
         value = super().__getitem__(key)
         if isinstance(value, self.__mutable_types__):
             # Recursively wrap mutable items retrieved through this proxy.
-            return MutableProxy(
+            return type(self)(
                 wrapped=value,
                 state=self._self_state,
                 field_name=self._self_field_name,
@@ -1332,3 +1778,34 @@ class MutableProxy(wrapt.ObjectProxy):
             A deepcopy of the wrapped object, unconnected to the proxy.
         """
         return copy.deepcopy(self.__wrapped__, memo=memo)
+
+
+class ImmutableMutableProxy(MutableProxy):
+    """A proxy for a mutable object that tracks changes.
+
+    This wrapper comes from StateProxy, and will raise an exception if an attempt is made
+    to modify the wrapped object when the StateProxy is immutable.
+    """
+
+    def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
+        """Raise an exception when an attempt is made to modify the object.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Raises:
+            ImmutableStateError: if the StateProxy is not mutable.
+        """
+        if not self._self_state._self_mutable:
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        super()._mark_dirty(
+            wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
+        )

+ 122 - 30
reflex/testing.py

@@ -1,6 +1,7 @@
 """reflex.testing - tools for testing reflex apps."""
 from __future__ import annotations
 
+import asyncio
 import contextlib
 import dataclasses
 import inspect
@@ -19,14 +20,13 @@ import types
 from http.server import SimpleHTTPRequestHandler
 from typing import (
     TYPE_CHECKING,
-    Any,
+    AsyncIterator,
     Callable,
     Coroutine,
     Optional,
     Type,
     TypeVar,
     Union,
-    cast,
 )
 
 import psutil
@@ -38,7 +38,7 @@ import reflex.utils.build
 import reflex.utils.exec
 import reflex.utils.prerequisites
 import reflex.utils.processes
-from reflex.app import EventNamespace
+from reflex.state import State, StateManagerMemory, StateManagerRedis
 
 try:
     from selenium import webdriver  # pyright: ignore [reportMissingImports]
@@ -109,6 +109,7 @@ class AppHarness:
     frontend_url: Optional[str] = None
     backend_thread: Optional[threading.Thread] = None
     backend: Optional[uvicorn.Server] = None
+    state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
     _frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
 
     @classmethod
@@ -162,6 +163,27 @@ class AppHarness:
             reflex.config.get_config(reload=True)
             self.app_module = reflex.utils.prerequisites.get_app(reload=True)
         self.app_instance = self.app_module.app
+        if isinstance(self.app_instance.state_manager, StateManagerRedis):
+            # Create our own redis connection for testing.
+            self.state_manager = StateManagerRedis.create(self.app_instance.state)
+        else:
+            self.state_manager = self.app_instance.state_manager
+
+    def _get_backend_shutdown_handler(self):
+        if self.backend is None:
+            raise RuntimeError("Backend was not initialized.")
+
+        original_shutdown = self.backend.shutdown
+
+        async def _shutdown_redis(*args, **kwargs) -> None:
+            # ensure redis is closed before event loop
+            if self.app_instance is not None and isinstance(
+                self.app_instance.state_manager, StateManagerRedis
+            ):
+                await self.app_instance.state_manager.redis.close()
+            await original_shutdown(*args, **kwargs)
+
+        return _shutdown_redis
 
     def _start_backend(self, port=0):
         if self.app_instance is None:
@@ -173,6 +195,7 @@ class AppHarness:
                 port=port,
             )
         )
+        self.backend.shutdown = self._get_backend_shutdown_handler()
         self.backend_thread = threading.Thread(target=self.backend.run)
         self.backend_thread.start()
 
@@ -296,6 +319,35 @@ class AppHarness:
             time.sleep(step)
         return False
 
+    @staticmethod
+    async def _poll_for_async(
+        target: Callable[[], Coroutine[None, None, T]],
+        timeout: TimeoutType = None,
+        step: TimeoutType = None,
+    ) -> T | bool:
+        """Generic polling logic for async functions.
+
+        Args:
+            target: callable that returns truthy if polling condition is met.
+            timeout: max polling time
+            step: interval between checking target()
+
+        Returns:
+            return value of target() if truthy within timeout
+            False if timeout elapses
+        """
+        if timeout is None:
+            timeout = DEFAULT_TIMEOUT
+        if step is None:
+            step = POLL_INTERVAL
+        deadline = time.time() + timeout
+        while time.time() < deadline:
+            success = await target()
+            if success:
+                return success
+            await asyncio.sleep(step)
+        return False
+
     def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
         """Poll backend server for listening sockets.
 
@@ -351,39 +403,76 @@ class AppHarness:
         self._frontends.append(driver)
         return driver
 
-    async def emit_state_updates(self) -> list[Any]:
-        """Send any backend state deltas to the frontend.
+    async def get_state(self, token: str) -> State:
+        """Get the state associated with the given token.
+
+        Args:
+            token: The state token to look up.
 
         Returns:
-            List of awaited response from each EventNamespace.emit() call.
+            The state instance associated with the given token
+
+        Raises:
+            RuntimeError: when the app hasn't started running
+        """
+        if self.state_manager is None:
+            raise RuntimeError("state_manager is not set.")
+        try:
+            return await self.state_manager.get_state(token)
+        finally:
+            if isinstance(self.state_manager, StateManagerRedis):
+                await self.state_manager.redis.close()
+
+    async def set_state(self, token: str, **kwargs) -> None:
+        """Set the state associated with the given token.
+
+        Args:
+            token: The state token to set.
+            kwargs: Attributes to set on the state.
+
+        Raises:
+            RuntimeError: when the app hasn't started running
+        """
+        if self.state_manager is None:
+            raise RuntimeError("state_manager is not set.")
+        state = await self.get_state(token)
+        for key, value in kwargs.items():
+            setattr(state, key, value)
+        try:
+            await self.state_manager.set_state(token, state)
+        finally:
+            if isinstance(self.state_manager, StateManagerRedis):
+                await self.state_manager.redis.close()
+
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state associated with the given token and send update to frontend.
+
+        Args:
+            token: The state token to modify
+
+        Yields:
+            The state instance associated with the given token
 
         Raises:
             RuntimeError: when the app hasn't started running
         """
-        if self.app_instance is None or self.app_instance.sio is None:
+        if self.state_manager is None:
+            raise RuntimeError("state_manager is not set.")
+        if self.app_instance is None:
             raise RuntimeError("App is not running.")
-        event_ns: EventNamespace = cast(
-            EventNamespace,
-            self.app_instance.event_namespace,
-        )
-        pending: list[Coroutine[Any, Any, Any]] = []
-        for state in self.app_instance.state_manager.states.values():
-            delta = state.get_delta()
-            if delta:
-                update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
-                state._clean()
-                # Emit the event.
-                pending.append(
-                    event_ns.emit(
-                        str(reflex.constants.SocketEvent.EVENT),
-                        update.json(),
-                        to=state.get_sid(),
-                    ),
-                )
-        responses = []
-        for request in pending:
-            responses.append(await request)
-        return responses
+        app_state_manager = self.app_instance.state_manager
+        if isinstance(self.state_manager, StateManagerRedis):
+            # Temporarily replace the app's state manager with our own, since
+            # the redis connection is on the backend_thread event loop
+            self.app_instance.state_manager = self.state_manager
+        try:
+            async with self.app_instance.modify_state(token) as state:
+                yield state
+        finally:
+            if isinstance(self.state_manager, StateManagerRedis):
+                self.app_instance.state_manager = app_state_manager
+                await self.state_manager.redis.close()
 
     def poll_for_content(
         self,
@@ -457,6 +546,9 @@ class AppHarness:
         if self.app_instance is None:
             raise RuntimeError("App is not running.")
         state_manager = self.app_instance.state_manager
+        assert isinstance(
+            state_manager, StateManagerMemory
+        ), "Only works with memory state manager"
         if not self._poll_for(
             target=lambda: state_manager.states,
             timeout=timeout,
@@ -534,7 +626,6 @@ class Subdir404TCPServer(socketserver.TCPServer):
             request: the requesting socket
             client_address: (host, port) referring to the client’s address.
         """
-        print(client_address, type(client_address))
         self.RequestHandlerClass(
             request,
             client_address,
@@ -605,6 +696,7 @@ class AppHarnessProd(AppHarness):
                 workers=reflex.utils.processes.get_num_workers(),
             ),
         )
+        self.backend.shutdown = self._get_backend_shutdown_handler()
         self.backend_thread = threading.Thread(target=self.backend.run)
         self.backend_thread.start()
 

+ 8 - 0
reflex/utils/exceptions.py

@@ -5,3 +5,11 @@ class InvalidStylePropError(TypeError):
     """Custom Type Error when style props have invalid values."""
 
     pass
+
+
+class ImmutableStateError(AttributeError):
+    """Raised when a background task attempts to modify state outside of context."""
+
+
+class LockExpiredError(Exception):
+    """Raised when the state lock expires while an event is being processed."""

+ 5 - 3
reflex/utils/prerequisites.py

@@ -21,7 +21,7 @@ import httpx
 import typer
 from alembic.util.exc import CommandError
 from packaging import version
-from redis import Redis
+from redis.asyncio import Redis
 
 from reflex import constants, model
 from reflex.compiler import templates
@@ -124,9 +124,11 @@ def get_redis() -> Redis | None:
         The redis client.
     """
     config = get_config()
-    if config.redis_url is None:
+    if not config.redis_url:
         return None
-    redis_url, redis_port = config.redis_url.split(":")
+    redis_url, has_port, redis_port = config.redis_url.partition(":")
+    if not has_port:
+        redis_port = 6379
     console.info(f"Using redis at {config.redis_url}")
     return Redis(host=redis_url, port=int(redis_port), db=0)
 

+ 21 - 380
tests/conftest.py

@@ -2,8 +2,9 @@
 import contextlib
 import os
 import platform
+import uuid
 from pathlib import Path
-from typing import Dict, Generator, List, Set, Union
+from typing import Dict, Generator
 
 import pytest
 
@@ -11,6 +12,14 @@ import reflex as rx
 from reflex.app import App
 from reflex.event import EventSpec
 
+from .states import (
+    DictMutationTestState,
+    ListMutationTestState,
+    MutableTestState,
+    SubUploadState,
+    UploadState,
+)
+
 
 @pytest.fixture
 def app() -> App:
@@ -39,60 +48,7 @@ def list_mutation_state():
     Returns:
         A state with list mutation features.
     """
-
-    class TestState(rx.State):
-        """The test state."""
-
-        # plain list
-        plain_friends = ["Tommy"]
-
-        def make_friend(self):
-            self.plain_friends.append("another-fd")
-
-        def change_first_friend(self):
-            self.plain_friends[0] = "Jenny"
-
-        def unfriend_all_friends(self):
-            self.plain_friends.clear()
-
-        def unfriend_first_friend(self):
-            del self.plain_friends[0]
-
-        def remove_last_friend(self):
-            self.plain_friends.pop()
-
-        def make_friends_with_colleagues(self):
-            colleagues = ["Peter", "Jimmy"]
-            self.plain_friends.extend(colleagues)
-
-        def remove_tommy(self):
-            self.plain_friends.remove("Tommy")
-
-        # list in dict
-        friends_in_dict = {"Tommy": ["Jenny"]}
-
-        def remove_jenny_from_tommy(self):
-            self.friends_in_dict["Tommy"].remove("Jenny")
-
-        def add_jimmy_to_tommy_friends(self):
-            self.friends_in_dict["Tommy"].append("Jimmy")
-
-        def tommy_has_no_fds(self):
-            self.friends_in_dict["Tommy"].clear()
-
-        # nested list
-        friends_in_nested_list = [["Tommy"], ["Jenny"]]
-
-        def remove_first_group(self):
-            self.friends_in_nested_list.pop(0)
-
-        def remove_first_person_from_first_group(self):
-            self.friends_in_nested_list[0].pop(0)
-
-        def add_jimmy_to_second_group(self):
-            self.friends_in_nested_list[1].append("Jimmy")
-
-    return TestState()
+    return ListMutationTestState()
 
 
 @pytest.fixture
@@ -102,85 +58,7 @@ def dict_mutation_state():
     Returns:
         A state with dict mutation features.
     """
-
-    class TestState(rx.State):
-        """The test state."""
-
-        # plain dict
-        details = {"name": "Tommy"}
-
-        def add_age(self):
-            self.details.update({"age": 20})  # type: ignore
-
-        def change_name(self):
-            self.details["name"] = "Jenny"
-
-        def remove_last_detail(self):
-            self.details.popitem()
-
-        def clear_details(self):
-            self.details.clear()
-
-        def remove_name(self):
-            del self.details["name"]
-
-        def pop_out_age(self):
-            self.details.pop("age")
-
-        # dict in list
-        address = [{"home": "home address"}, {"work": "work address"}]
-
-        def remove_home_address(self):
-            self.address[0].pop("home")
-
-        def add_street_to_home_address(self):
-            self.address[0]["street"] = "street address"
-
-        # nested dict
-        friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
-
-        def change_friend_name(self):
-            self.friend_in_nested_dict["friend"]["name"] = "Tommy"
-
-        def remove_friend(self):
-            self.friend_in_nested_dict.pop("friend")
-
-        def add_friend_age(self):
-            self.friend_in_nested_dict["friend"]["age"] = 30
-
-    return TestState()
-
-
-class UploadState(rx.State):
-    """The base state for uploading a file."""
-
-    async def handle_upload1(self, files: List[rx.UploadFile]):
-        """Handle the upload of a file.
-
-        Args:
-            files: The uploaded files.
-        """
-        pass
-
-
-class BaseState(rx.State):
-    """The test base state."""
-
-    pass
-
-
-class SubUploadState(BaseState):
-    """The test substate."""
-
-    img: str
-
-    async def handle_upload(self, files: List[rx.UploadFile]):
-        """Handle the upload of a file.
-
-        Args:
-            files: The uploaded files.
-        """
-        pass
+    return DictMutationTestState()
 
 
 @pytest.fixture
@@ -203,187 +81,6 @@ def upload_event_spec():
     return EventSpec(handler=UploadState.handle_upload1, upload=True)  # type: ignore
 
 
-@pytest.fixture
-def upload_state(tmp_path):
-    """Create upload state.
-
-    Args:
-        tmp_path: pytest tmp_path
-
-    Returns:
-        The state
-
-    """
-
-    class FileUploadState(rx.State):
-        """The base state for uploading a file."""
-
-        img_list: List[str]
-
-        async def handle_upload2(self, files):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                self.img_list.append(file.filename)
-
-        async def multi_handle_upload(self, files: List[rx.UploadFile]):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                assert file.filename is not None
-                self.img_list.append(file.filename)
-
-    return FileUploadState
-
-
-@pytest.fixture
-def upload_sub_state(tmp_path):
-    """Create upload substate.
-
-    Args:
-        tmp_path: pytest tmp_path
-
-    Returns:
-        The state
-
-    """
-
-    class FileState(rx.State):
-        """The base state."""
-
-        pass
-
-    class FileUploadState(FileState):
-        """The substate for uploading a file."""
-
-        img_list: List[str]
-
-        async def handle_upload2(self, files):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                self.img_list.append(file.filename)
-
-        async def multi_handle_upload(self, files: List[rx.UploadFile]):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                assert file.filename is not None
-                self.img_list.append(file.filename)
-
-    return FileUploadState
-
-
-@pytest.fixture
-def upload_grand_sub_state(tmp_path):
-    """Create upload grand-state.
-
-    Args:
-        tmp_path: pytest tmp_path
-
-    Returns:
-        The state
-
-    """
-
-    class BaseFileState(rx.State):
-        """The base state."""
-
-        pass
-
-    class FileSubState(BaseFileState):
-        """The substate."""
-
-        pass
-
-    class FileUploadState(FileSubState):
-        """The grand-substate for uploading a file."""
-
-        img_list: List[str]
-
-        async def handle_upload2(self, files):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                assert file.filename is not None
-                self.img_list.append(file.filename)
-
-        async def multi_handle_upload(self, files: List[rx.UploadFile]):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                assert file.filename is not None
-                self.img_list.append(file.filename)
-
-    return FileUploadState
-
-
 @pytest.fixture
 def base_config_values() -> Dict:
     """Get base config values.
@@ -418,35 +115,6 @@ def sqlite_db_config_values(base_db_config_values) -> Dict:
     return base_db_config_values
 
 
-class GenState(rx.State):
-    """A state with event handlers that generate multiple updates."""
-
-    value: int
-
-    def go(self, c: int):
-        """Increment the value c times and update each time.
-
-        Args:
-            c: The number of times to increment.
-
-        Yields:
-            After each increment.
-        """
-        for _ in range(c):
-            self.value += 1
-            yield
-
-
-@pytest.fixture
-def gen_state() -> GenState:
-    """A state.
-
-    Returns:
-        A test state.
-    """
-    return GenState  # type: ignore
-
-
 @pytest.fixture
 def router_data_headers() -> Dict[str, str]:
     """Router data headers.
@@ -546,44 +214,17 @@ def mutable_state():
     Returns:
         A state object.
     """
+    return MutableTestState()
 
-    class OtherBase(rx.Base):
-        bar: str = ""
-
-    class CustomVar(rx.Base):
-        foo: str = ""
-        array: List[str] = []
-        hashmap: Dict[str, str] = {}
-        test_set: Set[str] = set()
-        custom: OtherBase = OtherBase()
-
-    class MutableTestState(rx.State):
-        """A test state."""
-
-        array: List[Union[str, List, Dict[str, str]]] = [
-            "value",
-            [1, 2, 3],
-            {"key": "value"},
-        ]
-        hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
-            "key": ["list", "of", "values"],
-            "another_key": "another_value",
-            "third_key": {"key": "value"},
-        }
-        test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
-        custom: CustomVar = CustomVar()
-        _be_custom: CustomVar = CustomVar()
-
-        def reassign_mutables(self):
-            self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
-            self.hashmap = {
-                "mod_key": ["list", "of", "values"],
-                "mod_another_key": "another_value",
-                "mod_third_key": {"key": "value"},
-            }
-            self.test_set = {1, 2, 3, 4, "five"}
 
-    return MutableTestState()
+@pytest.fixture(scope="function")
+def token() -> str:
+    """Create a token.
+
+    Returns:
+        A fresh/unique token string.
+    """
+    return str(uuid.uuid4())
 
 
 @pytest.fixture

+ 30 - 0
tests/states/__init__.py

@@ -0,0 +1,30 @@
+"""Common rx.State subclasses for use in tests."""
+import reflex as rx
+
+from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
+from .upload import (
+    ChildFileUploadState,
+    FileUploadState,
+    GrandChildFileUploadState,
+    SubUploadState,
+    UploadState,
+)
+
+
+class GenState(rx.State):
+    """A state with event handlers that generate multiple updates."""
+
+    value: int
+
+    def go(self, c: int):
+        """Increment the value c times and update each time.
+
+        Args:
+            c: The number of times to increment.
+
+        Yields:
+            After each increment.
+        """
+        for _ in range(c):
+            self.value += 1
+            yield

+ 172 - 0
tests/states/mutation.py

@@ -0,0 +1,172 @@
+"""Test states for mutable vars."""
+
+from typing import Dict, List, Set, Union
+
+import reflex as rx
+
+
+class DictMutationTestState(rx.State):
+    """A state for testing ReflexDict mutation."""
+
+    # plain dict
+    details = {"name": "Tommy"}
+
+    def add_age(self):
+        """Add an age to the dict."""
+        self.details.update({"age": 20})  # type: ignore
+
+    def change_name(self):
+        """Change the name in the dict."""
+        self.details["name"] = "Jenny"
+
+    def remove_last_detail(self):
+        """Remove the last item in the dict."""
+        self.details.popitem()
+
+    def clear_details(self):
+        """Clear the dict."""
+        self.details.clear()
+
+    def remove_name(self):
+        """Remove the name from the dict."""
+        del self.details["name"]
+
+    def pop_out_age(self):
+        """Pop out the age from the dict."""
+        self.details.pop("age")
+
+    # dict in list
+    address = [{"home": "home address"}, {"work": "work address"}]
+
+    def remove_home_address(self):
+        """Remove the home address from dict in the list."""
+        self.address[0].pop("home")
+
+    def add_street_to_home_address(self):
+        """Set street key in the dict in the list."""
+        self.address[0]["street"] = "street address"
+
+    # nested dict
+    friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
+
+    def change_friend_name(self):
+        """Change the friend's name in the nested dict."""
+        self.friend_in_nested_dict["friend"]["name"] = "Tommy"
+
+    def remove_friend(self):
+        """Remove the friend from the nested dict."""
+        self.friend_in_nested_dict.pop("friend")
+
+    def add_friend_age(self):
+        """Add an age to the friend in the nested dict."""
+        self.friend_in_nested_dict["friend"]["age"] = 30
+
+
+class ListMutationTestState(rx.State):
+    """A state for testing ReflexList mutation."""
+
+    # plain list
+    plain_friends = ["Tommy"]
+
+    def make_friend(self):
+        """Add a friend to the list."""
+        self.plain_friends.append("another-fd")
+
+    def change_first_friend(self):
+        """Change the first friend in the list."""
+        self.plain_friends[0] = "Jenny"
+
+    def unfriend_all_friends(self):
+        """Unfriend all friends in the list."""
+        self.plain_friends.clear()
+
+    def unfriend_first_friend(self):
+        """Unfriend the first friend in the list."""
+        del self.plain_friends[0]
+
+    def remove_last_friend(self):
+        """Remove the last friend in the list."""
+        self.plain_friends.pop()
+
+    def make_friends_with_colleagues(self):
+        """Add list of friends to the list."""
+        colleagues = ["Peter", "Jimmy"]
+        self.plain_friends.extend(colleagues)
+
+    def remove_tommy(self):
+        """Remove Tommy from the list."""
+        self.plain_friends.remove("Tommy")
+
+    # list in dict
+    friends_in_dict = {"Tommy": ["Jenny"]}
+
+    def remove_jenny_from_tommy(self):
+        """Remove Jenny from Tommy's friends list."""
+        self.friends_in_dict["Tommy"].remove("Jenny")
+
+    def add_jimmy_to_tommy_friends(self):
+        """Add Jimmy to Tommy's friends list."""
+        self.friends_in_dict["Tommy"].append("Jimmy")
+
+    def tommy_has_no_fds(self):
+        """Clear Tommy's friends list."""
+        self.friends_in_dict["Tommy"].clear()
+
+    # nested list
+    friends_in_nested_list = [["Tommy"], ["Jenny"]]
+
+    def remove_first_group(self):
+        """Remove the first group of friends from the nested list."""
+        self.friends_in_nested_list.pop(0)
+
+    def remove_first_person_from_first_group(self):
+        """Remove the first person from the first group of friends in the nested list."""
+        self.friends_in_nested_list[0].pop(0)
+
+    def add_jimmy_to_second_group(self):
+        """Add Jimmy to the second group of friends in the nested list."""
+        self.friends_in_nested_list[1].append("Jimmy")
+
+
+class OtherBase(rx.Base):
+    """A Base model with a str field."""
+
+    bar: str = ""
+
+
+class CustomVar(rx.Base):
+    """A Base model with multiple fields."""
+
+    foo: str = ""
+    array: List[str] = []
+    hashmap: Dict[str, str] = {}
+    test_set: Set[str] = set()
+    custom: OtherBase = OtherBase()
+
+
+class MutableTestState(rx.State):
+    """A test state."""
+
+    array: List[Union[str, List, Dict[str, str]]] = [
+        "value",
+        [1, 2, 3],
+        {"key": "value"},
+    ]
+    hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
+        "key": ["list", "of", "values"],
+        "another_key": "another_value",
+        "third_key": {"key": "value"},
+    }
+    test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
+    custom: CustomVar = CustomVar()
+    _be_custom: CustomVar = CustomVar()
+
+    def reassign_mutables(self):
+        """Assign mutable fields to different values."""
+        self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
+        self.hashmap = {
+            "mod_key": ["list", "of", "values"],
+            "mod_another_key": "another_value",
+            "mod_third_key": {"key": "value"},
+        }
+        self.test_set = {1, 2, 3, 4, "five"}

+ 175 - 0
tests/states/upload.py

@@ -0,0 +1,175 @@
+"""Test states for upload-related tests."""
+from pathlib import Path
+from typing import ClassVar, List
+
+import reflex as rx
+
+
+class UploadState(rx.State):
+    """The base state for uploading a file."""
+
+    async def handle_upload1(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass
+
+
+class BaseState(rx.State):
+    """The test base state."""
+
+    pass
+
+
+class SubUploadState(BaseState):
+    """The test substate."""
+
+    img: str
+
+    async def handle_upload(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass
+
+
+class FileUploadState(rx.State):
+    """The base state for uploading a file."""
+
+    img_list: List[str]
+    _tmp_path: ClassVar[Path]
+
+    async def handle_upload2(self, files):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            self.img_list.append(file.filename)
+
+    async def multi_handle_upload(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            assert file.filename is not None
+            self.img_list.append(file.filename)
+
+
+class FileStateBase1(rx.State):
+    """The base state for a child FileUploadState."""
+
+    pass
+
+
+class ChildFileUploadState(FileStateBase1):
+    """The child state for uploading a file."""
+
+    img_list: List[str]
+    _tmp_path: ClassVar[Path]
+
+    async def handle_upload2(self, files):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            self.img_list.append(file.filename)
+
+    async def multi_handle_upload(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            assert file.filename is not None
+            self.img_list.append(file.filename)
+
+
+class FileStateBase2(FileStateBase1):
+    """The parent state for a grandchild FileUploadState."""
+
+    pass
+
+
+class GrandChildFileUploadState(FileStateBase2):
+    """The child state for uploading a file."""
+
+    img_list: List[str]
+    _tmp_path: ClassVar[Path]
+
+    async def handle_upload2(self, files):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            self.img_list.append(file.filename)
+
+    async def multi_handle_upload(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            assert file.filename is not None
+            self.img_list.append(file.filename)

+ 237 - 121
tests/test_app.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import io
 import os.path
 import sys
+import uuid
 from typing import List, Tuple, Type
 
 if sys.version_info.major >= 3 and sys.version_info.minor > 7:
@@ -30,11 +31,18 @@ from reflex.components import Box, Component, Cond, Fragment, Text
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
-from reflex.state import State, StateUpdate
+from reflex.state import State, StateManagerRedis, StateUpdate
 from reflex.style import Style
 from reflex.utils import format
 from reflex.vars import ComputedVar
 
+from .states import (
+    ChildFileUploadState,
+    FileUploadState,
+    GenState,
+    GrandChildFileUploadState,
+)
+
 
 @pytest.fixture
 def index_page():
@@ -64,6 +72,12 @@ def about_page():
     return about
 
 
+class ATestState(State):
+    """A simple state for testing."""
+
+    var: int
+
+
 @pytest.fixture()
 def test_state() -> Type[State]:
     """A default state.
@@ -71,11 +85,7 @@ def test_state() -> Type[State]:
     Returns:
         A default state.
     """
-
-    class TestState(State):
-        var: int
-
-    return TestState
+    return ATestState
 
 
 @pytest.fixture()
@@ -313,23 +323,28 @@ def test_initialize_admin_dashboard_with_view_overrides(test_model):
     assert app.admin_dash.view_overrides[test_model] == TestModelView
 
 
-def test_initialize_with_state(test_state):
+@pytest.mark.asyncio
+async def test_initialize_with_state(test_state: Type[ATestState], token: str):
     """Test setting the state of an app.
 
     Args:
         test_state: The default state.
+        token: a Token.
     """
     app = App(state=test_state)
     assert app.state == test_state
 
     # Get a state for a given token.
-    token = "token"
-    state = app.state_manager.get_state(token)
+    state = await app.state_manager.get_state(token)
     assert isinstance(state, test_state)
     assert state.var == 0  # type: ignore
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
-def test_set_and_get_state(test_state):
+@pytest.mark.asyncio
+async def test_set_and_get_state(test_state):
     """Test setting and getting the state of an app with different tokens.
 
     Args:
@@ -338,47 +353,51 @@ def test_set_and_get_state(test_state):
     app = App(state=test_state)
 
     # Create two tokens.
-    token1 = "token1"
-    token2 = "token2"
+    token1 = str(uuid.uuid4())
+    token2 = str(uuid.uuid4())
 
     # Get the default state for each token.
-    state1 = app.state_manager.get_state(token1)
-    state2 = app.state_manager.get_state(token2)
+    state1 = await app.state_manager.get_state(token1)
+    state2 = await app.state_manager.get_state(token2)
     assert state1.var == 0  # type: ignore
     assert state2.var == 0  # type: ignore
 
     # Set the vars to different values.
     state1.var = 1
     state2.var = 2
-    app.state_manager.set_state(token1, state1)
-    app.state_manager.set_state(token2, state2)
+    await app.state_manager.set_state(token1, state1)
+    await app.state_manager.set_state(token2, state2)
 
     # Get the states again and check the values.
-    state1 = app.state_manager.get_state(token1)
-    state2 = app.state_manager.get_state(token2)
+    state1 = await app.state_manager.get_state(token1)
+    state2 = await app.state_manager.get_state(token2)
     assert state1.var == 1  # type: ignore
     assert state2.var == 2  # type: ignore
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 @pytest.mark.asyncio
-async def test_dynamic_var_event(test_state):
+async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
     """Test that the default handler of a dynamic generated var
     works as expected.
 
     Args:
         test_state: State Fixture.
+        token: a Token.
     """
-    test_state = test_state()
-    test_state.add_var("int_val", int, 0)
-    result = await test_state._process(
+    state = test_state()  # type: ignore
+    state.add_var("int_val", int, 0)
+    result = await state._process(
         Event(
-            token="fake-token",
-            name="test_state.set_int_val",
+            token=token,
+            name=f"{test_state.get_name()}.set_int_val",
             router_data={"pathname": "/", "query": {}},
             payload={"value": 50},
         )
     ).__anext__()
-    assert result.delta == {"test_state": {"int_val": 50}}
+    assert result.delta == {test_state.get_name(): {"int_val": 50}}
 
 
 @pytest.mark.asyncio
@@ -388,12 +407,20 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
             [
                 (
-                    "test_state.make_friend",
-                    {"test_state": {"plain_friends": ["Tommy", "another-fd"]}},
+                    "list_mutation_test_state.make_friend",
+                    {
+                        "list_mutation_test_state": {
+                            "plain_friends": ["Tommy", "another-fd"]
+                        }
+                    },
                 ),
                 (
-                    "test_state.change_first_friend",
-                    {"test_state": {"plain_friends": ["Jenny", "another-fd"]}},
+                    "list_mutation_test_state.change_first_friend",
+                    {
+                        "list_mutation_test_state": {
+                            "plain_friends": ["Jenny", "another-fd"]
+                        }
+                    },
                 ),
             ],
             id="append then __setitem__",
@@ -401,12 +428,12 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
             [
                 (
-                    "test_state.unfriend_first_friend",
-                    {"test_state": {"plain_friends": []}},
+                    "list_mutation_test_state.unfriend_first_friend",
+                    {"list_mutation_test_state": {"plain_friends": []}},
                 ),
                 (
-                    "test_state.make_friend",
-                    {"test_state": {"plain_friends": ["another-fd"]}},
+                    "list_mutation_test_state.make_friend",
+                    {"list_mutation_test_state": {"plain_friends": ["another-fd"]}},
                 ),
             ],
             id="delitem then append",
@@ -414,20 +441,24 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
             [
                 (
-                    "test_state.make_friends_with_colleagues",
-                    {"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}},
+                    "list_mutation_test_state.make_friends_with_colleagues",
+                    {
+                        "list_mutation_test_state": {
+                            "plain_friends": ["Tommy", "Peter", "Jimmy"]
+                        }
+                    },
                 ),
                 (
-                    "test_state.remove_tommy",
-                    {"test_state": {"plain_friends": ["Peter", "Jimmy"]}},
+                    "list_mutation_test_state.remove_tommy",
+                    {"list_mutation_test_state": {"plain_friends": ["Peter", "Jimmy"]}},
                 ),
                 (
-                    "test_state.remove_last_friend",
-                    {"test_state": {"plain_friends": ["Peter"]}},
+                    "list_mutation_test_state.remove_last_friend",
+                    {"list_mutation_test_state": {"plain_friends": ["Peter"]}},
                 ),
                 (
-                    "test_state.unfriend_all_friends",
-                    {"test_state": {"plain_friends": []}},
+                    "list_mutation_test_state.unfriend_all_friends",
+                    {"list_mutation_test_state": {"plain_friends": []}},
                 ),
             ],
             id="extend, remove, pop, clear",
@@ -435,24 +466,28 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
             [
                 (
-                    "test_state.add_jimmy_to_second_group",
+                    "list_mutation_test_state.add_jimmy_to_second_group",
                     {
-                        "test_state": {
+                        "list_mutation_test_state": {
                             "friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
                         }
                     },
                 ),
                 (
-                    "test_state.remove_first_person_from_first_group",
+                    "list_mutation_test_state.remove_first_person_from_first_group",
                     {
-                        "test_state": {
+                        "list_mutation_test_state": {
                             "friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
                         }
                     },
                 ),
                 (
-                    "test_state.remove_first_group",
-                    {"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}},
+                    "list_mutation_test_state.remove_first_group",
+                    {
+                        "list_mutation_test_state": {
+                            "friends_in_nested_list": [["Jenny", "Jimmy"]]
+                        }
+                    },
                 ),
             ],
             id="nested list",
@@ -460,16 +495,24 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
             [
                 (
-                    "test_state.add_jimmy_to_tommy_friends",
-                    {"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}},
+                    "list_mutation_test_state.add_jimmy_to_tommy_friends",
+                    {
+                        "list_mutation_test_state": {
+                            "friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}
+                        }
+                    },
                 ),
                 (
-                    "test_state.remove_jenny_from_tommy",
-                    {"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}},
+                    "list_mutation_test_state.remove_jenny_from_tommy",
+                    {
+                        "list_mutation_test_state": {
+                            "friends_in_dict": {"Tommy": ["Jimmy"]}
+                        }
+                    },
                 ),
                 (
-                    "test_state.tommy_has_no_fds",
-                    {"test_state": {"friends_in_dict": {"Tommy": []}}},
+                    "list_mutation_test_state.tommy_has_no_fds",
+                    {"list_mutation_test_state": {"friends_in_dict": {"Tommy": []}}},
                 ),
             ],
             id="list in dict",
@@ -477,7 +520,9 @@ async def test_dynamic_var_event(test_state):
     ],
 )
 async def test_list_mutation_detection__plain_list(
-    event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State
+    event_tuples: List[Tuple[str, List[str]]],
+    list_mutation_state: State,
+    token: str,
 ):
     """Test list mutation detection
     when reassignment is not explicitly included in the logic.
@@ -485,11 +530,12 @@ async def test_list_mutation_detection__plain_list(
     Args:
         event_tuples: From parametrization.
         list_mutation_state: A state with list mutation features.
+        token: a Token.
     """
     for event_name, expected_delta in event_tuples:
         result = await list_mutation_state._process(
             Event(
-                token="fake-token",
+                token=token,
                 name=event_name,
                 router_data={"pathname": "/", "query": {}},
                 payload={},
@@ -506,16 +552,24 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
             [
                 (
-                    "test_state.add_age",
-                    {"test_state": {"details": {"name": "Tommy", "age": 20}}},
+                    "dict_mutation_test_state.add_age",
+                    {
+                        "dict_mutation_test_state": {
+                            "details": {"name": "Tommy", "age": 20}
+                        }
+                    },
                 ),
                 (
-                    "test_state.change_name",
-                    {"test_state": {"details": {"name": "Jenny", "age": 20}}},
+                    "dict_mutation_test_state.change_name",
+                    {
+                        "dict_mutation_test_state": {
+                            "details": {"name": "Jenny", "age": 20}
+                        }
+                    },
                 ),
                 (
-                    "test_state.remove_last_detail",
-                    {"test_state": {"details": {"name": "Jenny"}}},
+                    "dict_mutation_test_state.remove_last_detail",
+                    {"dict_mutation_test_state": {"details": {"name": "Jenny"}}},
                 ),
             ],
             id="update then __setitem__",
@@ -523,12 +577,12 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
             [
                 (
-                    "test_state.clear_details",
-                    {"test_state": {"details": {}}},
+                    "dict_mutation_test_state.clear_details",
+                    {"dict_mutation_test_state": {"details": {}}},
                 ),
                 (
-                    "test_state.add_age",
-                    {"test_state": {"details": {"age": 20}}},
+                    "dict_mutation_test_state.add_age",
+                    {"dict_mutation_test_state": {"details": {"age": 20}}},
                 ),
             ],
             id="delitem then update",
@@ -536,16 +590,20 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
             [
                 (
-                    "test_state.add_age",
-                    {"test_state": {"details": {"name": "Tommy", "age": 20}}},
+                    "dict_mutation_test_state.add_age",
+                    {
+                        "dict_mutation_test_state": {
+                            "details": {"name": "Tommy", "age": 20}
+                        }
+                    },
                 ),
                 (
-                    "test_state.remove_name",
-                    {"test_state": {"details": {"age": 20}}},
+                    "dict_mutation_test_state.remove_name",
+                    {"dict_mutation_test_state": {"details": {"age": 20}}},
                 ),
                 (
-                    "test_state.pop_out_age",
-                    {"test_state": {"details": {}}},
+                    "dict_mutation_test_state.pop_out_age",
+                    {"dict_mutation_test_state": {"details": {}}},
                 ),
             ],
             id="add, remove, pop",
@@ -553,13 +611,17 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
             [
                 (
-                    "test_state.remove_home_address",
-                    {"test_state": {"address": [{}, {"work": "work address"}]}},
+                    "dict_mutation_test_state.remove_home_address",
+                    {
+                        "dict_mutation_test_state": {
+                            "address": [{}, {"work": "work address"}]
+                        }
+                    },
                 ),
                 (
-                    "test_state.add_street_to_home_address",
+                    "dict_mutation_test_state.add_street_to_home_address",
                     {
-                        "test_state": {
+                        "dict_mutation_test_state": {
                             "address": [
                                 {"street": "street address"},
                                 {"work": "work address"},
@@ -573,9 +635,9 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
             [
                 (
-                    "test_state.change_friend_name",
+                    "dict_mutation_test_state.change_friend_name",
                     {
-                        "test_state": {
+                        "dict_mutation_test_state": {
                             "friend_in_nested_dict": {
                                 "name": "Nikhil",
                                 "friend": {"name": "Tommy"},
@@ -584,9 +646,9 @@ async def test_list_mutation_detection__plain_list(
                     },
                 ),
                 (
-                    "test_state.add_friend_age",
+                    "dict_mutation_test_state.add_friend_age",
                     {
-                        "test_state": {
+                        "dict_mutation_test_state": {
                             "friend_in_nested_dict": {
                                 "name": "Nikhil",
                                 "friend": {"name": "Tommy", "age": 30},
@@ -595,8 +657,12 @@ async def test_list_mutation_detection__plain_list(
                     },
                 ),
                 (
-                    "test_state.remove_friend",
-                    {"test_state": {"friend_in_nested_dict": {"name": "Nikhil"}}},
+                    "dict_mutation_test_state.remove_friend",
+                    {
+                        "dict_mutation_test_state": {
+                            "friend_in_nested_dict": {"name": "Nikhil"}
+                        }
+                    },
                 ),
             ],
             id="nested dict",
@@ -604,7 +670,9 @@ async def test_list_mutation_detection__plain_list(
     ],
 )
 async def test_dict_mutation_detection__plain_list(
-    event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State
+    event_tuples: List[Tuple[str, List[str]]],
+    dict_mutation_state: State,
+    token: str,
 ):
     """Test dict mutation detection
     when reassignment is not explicitly included in the logic.
@@ -612,11 +680,12 @@ async def test_dict_mutation_detection__plain_list(
     Args:
         event_tuples: From parametrization.
         dict_mutation_state: A state with dict mutation features.
+        token: a Token.
     """
     for event_name, expected_delta in event_tuples:
         result = await dict_mutation_state._process(
             Event(
-                token="fake-token",
+                token=token,
                 name=event_name,
                 router_data={"pathname": "/", "query": {}},
                 payload={},
@@ -628,41 +697,43 @@ async def test_dict_mutation_detection__plain_list(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "fixture, delta",
+    ("state", "delta"),
     [
         (
-            "upload_state",
+            FileUploadState,
             {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
         ),
         (
-            "upload_sub_state",
+            ChildFileUploadState,
             {
-                "file_state.file_upload_state": {
+                "file_state_base1.child_file_upload_state": {
                     "img_list": ["image1.jpg", "image2.jpg"]
                 }
             },
         ),
         (
-            "upload_grand_sub_state",
+            GrandChildFileUploadState,
             {
-                "base_file_state.file_sub_state.file_upload_state": {
+                "file_state_base1.file_state_base2.grand_child_file_upload_state": {
                     "img_list": ["image1.jpg", "image2.jpg"]
                 }
             },
         ),
     ],
 )
-async def test_upload_file(fixture, request, delta):
+async def test_upload_file(tmp_path, state, delta, token: str):
     """Test that file upload works correctly.
 
     Args:
-        fixture: The state.
-        request: Fixture request.
+        tmp_path: Temporary path.
+        state: The state class.
         delta: Expected delta
+        token: a Token.
     """
-    app = App(state=request.getfixturevalue(fixture))
+    state._tmp_path = tmp_path
+    app = App(state=state)
     app.event_namespace.emit = AsyncMock()  # type: ignore
-    current_state = app.state_manager.get_state("token")
+    current_state = await app.state_manager.get_state(token)
     data = b"This is binary data"
 
     # Create a binary IO object and write data to it
@@ -670,11 +741,11 @@ async def test_upload_file(fixture, request, delta):
     bio.write(data)
 
     file1 = UploadFile(
-        filename="token:file_upload_state.multi_handle_upload:True:image1.jpg",
+        filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image1.jpg",
         file=bio,
     )
     file2 = UploadFile(
-        filename="token:file_upload_state.multi_handle_upload:True:image2.jpg",
+        filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image2.jpg",
         file=bio,
     )
     upload_fn = upload(app)
@@ -684,22 +755,27 @@ async def test_upload_file(fixture, request, delta):
     app.event_namespace.emit.assert_called_with(  # type: ignore
         "event", state_update.json(), to=current_state.get_sid()
     )
-    assert app.state_manager.get_state("token").dict()["img_list"] == [
+    assert (await app.state_manager.get_state(token)).dict()["img_list"] == [
         "image1.jpg",
         "image2.jpg",
     ]
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "fixture", ["upload_state", "upload_sub_state", "upload_grand_sub_state"]
+    "state",
+    [FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
 )
-async def test_upload_file_without_annotation(fixture, request):
+async def test_upload_file_without_annotation(state, tmp_path, token):
     """Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].
 
     Args:
-        fixture: The state.
-        request: Fixture request.
+        state: The state class.
+        tmp_path: Temporary path.
+        token: a Token.
     """
     data = b"This is binary data"
 
@@ -707,14 +783,15 @@ async def test_upload_file_without_annotation(fixture, request):
     bio = io.BytesIO()
     bio.write(data)
 
-    app = App(state=request.getfixturevalue(fixture))
+    state._tmp_path = tmp_path
+    app = App(state=state)
 
     file1 = UploadFile(
-        filename="token:file_upload_state.handle_upload2:True:image1.jpg",
+        filename=f"{token}:{state.get_name()}.handle_upload2:True:image1.jpg",
         file=bio,
     )
     file2 = UploadFile(
-        filename="token:file_upload_state.handle_upload2:True:image2.jpg",
+        filename=f"{token}:{state.get_name()}.handle_upload2:True:image2.jpg",
         file=bio,
     )
     fn = upload(app)
@@ -722,9 +799,12 @@ async def test_upload_file_without_annotation(fixture, request):
         await fn([file1, file2])
     assert (
         err.value.args[0]
-        == "`file_upload_state.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
+        == f"`{state.get_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
     )
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 class DynamicState(State):
     """State class for testing dynamic route var.
@@ -768,6 +848,7 @@ class DynamicState(State):
 async def test_dynamic_route_var_route_change_completed_on_load(
     index_page,
     windows_platform: bool,
+    token: str,
 ):
     """Create app with dynamic route var, and simulate navigation.
 
@@ -777,6 +858,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     Args:
         index_page: The index page.
         windows_platform: Whether the system is windows.
+        token: a Token.
     """
     arg_name = "dynamic"
     route = f"/test/[{arg_name}]"
@@ -792,10 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     }
     assert constants.ROUTER_DATA in app.state().computed_var_dependencies
 
-    token = "mock_token"
     sid = "mock_sid"
     client_ip = "127.0.0.1"
-    state = app.state_manager.get_state(token)
+    state = await app.state_manager.get_state(token)
     assert state.dynamic == ""
     exp_vals = ["foo", "foobar", "baz"]
 
@@ -817,6 +898,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             **kwargs,
         )
 
+    prev_exp_val = ""
     for exp_index, exp_val in enumerate(exp_vals):
         hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
         exp_router_data = {
@@ -826,13 +908,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             "token": token,
             **hydrate_event.router_data,
         }
-        update = await process(
+        process_coro = process(
             app,
             event=hydrate_event,
             sid=sid,
             headers={},
             client_ip=client_ip,
-        ).__anext__()  # type: ignore
+        )
+        update = await process_coro.__anext__()  # type: ignore
 
         # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
@@ -860,14 +943,27 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                 ),
             ],
         )
+        if isinstance(app.state_manager, StateManagerRedis):
+            # When redis is used, the state is not updated until the processing is complete
+            state = await app.state_manager.get_state(token)
+            assert state.dynamic == prev_exp_val
+
+        # complete the processing
+        with pytest.raises(StopAsyncIteration):
+            await process_coro.__anext__()  # type: ignore
+
+        # check that router data was written to the state_manager store
+        state = await app.state_manager.get_state(token)
         assert state.dynamic == exp_val
-        on_load_update = await process(
+
+        process_coro = process(
             app,
             event=_dynamic_state_event(name="on_load", val=exp_val),
             sid=sid,
             headers={},
             client_ip=client_ip,
-        ).__anext__()  # type: ignore
+        )
+        on_load_update = await process_coro.__anext__()  # type: ignore
         assert on_load_update == StateUpdate(
             delta={
                 state.get_name(): {
@@ -879,7 +975,10 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             },
             events=[],
         )
-        on_set_is_hydrated_update = await process(
+        # complete the processing
+        with pytest.raises(StopAsyncIteration):
+            await process_coro.__anext__()  # type: ignore
+        process_coro = process(
             app,
             event=_dynamic_state_event(
                 name="set_is_hydrated", payload={"value": True}, val=exp_val
@@ -887,7 +986,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             sid=sid,
             headers={},
             client_ip=client_ip,
-        ).__anext__()  # type: ignore
+        )
+        on_set_is_hydrated_update = await process_coro.__anext__()  # type: ignore
         assert on_set_is_hydrated_update == StateUpdate(
             delta={
                 state.get_name(): {
@@ -899,15 +999,19 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             },
             events=[],
         )
+        # complete the processing
+        with pytest.raises(StopAsyncIteration):
+            await process_coro.__anext__()  # type: ignore
 
         # a simple state update event should NOT trigger on_load or route var side effects
-        update = await process(
+        process_coro = process(
             app,
             event=_dynamic_state_event(name="on_counter", val=exp_val),
             sid=sid,
             headers={},
             client_ip=client_ip,
-        ).__anext__()  # type: ignore
+        )
+        update = await process_coro.__anext__()  # type: ignore
         assert update == StateUpdate(
             delta={
                 state.get_name(): {
@@ -919,42 +1023,54 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             },
             events=[],
         )
+        # complete the processing
+        with pytest.raises(StopAsyncIteration):
+            await process_coro.__anext__()  # type: ignore
+
+        prev_exp_val = exp_val
+    state = await app.state_manager.get_state(token)
     assert state.loaded == len(exp_vals)
     assert state.counter == len(exp_vals)
     # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
     # assert state.side_effect_counter == len(exp_vals)
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 @pytest.mark.asyncio
-async def test_process_events(gen_state, mocker):
+async def test_process_events(mocker, token: str):
     """Test that an event is processed properly and that it is postprocessed
     n+1 times. Also check that the processing flag of the last stateupdate is set to
     False.
 
     Args:
-        gen_state: The state.
         mocker: mocker object.
+        token: a Token.
     """
     router_data = {
         "pathname": "/",
         "query": {},
-        "token": "mock_token",
+        "token": token,
         "sid": "mock_sid",
         "headers": {},
         "ip": "127.0.0.1",
     }
-    app = App(state=gen_state)
+    app = App(state=GenState)
     mocker.patch.object(app, "postprocess", AsyncMock())
     event = Event(
-        token="token", name="gen_state.go", payload={"c": 5}, router_data=router_data
+        token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
     )
 
     async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):  # type: ignore
         pass
 
-    assert app.state_manager.get_state("token").value == 5
+    assert (await app.state_manager.get_state(token)).value == 5
     assert app.postprocess.call_count == 6
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 @pytest.mark.parametrize(
     ("state", "overlay_component", "exp_page_child"),

+ 417 - 11
tests/test_state.py

@@ -1,22 +1,42 @@
 from __future__ import annotations
 
+import asyncio
 import copy
 import datetime
 import functools
+import json
+import os
 import sys
-from typing import Dict, List
+from typing import Dict, Generator, List
+from unittest.mock import AsyncMock, Mock
 
 import pytest
 from plotly.graph_objects import Figure
 
 import reflex as rx
 from reflex.base import Base
-from reflex.constants import IS_HYDRATED, RouteVar
+from reflex.constants import APP_VAR, IS_HYDRATED, RouteVar, SocketEvent
 from reflex.event import Event, EventHandler
-from reflex.state import MutableProxy, State
-from reflex.utils import format
+from reflex.state import (
+    ImmutableStateError,
+    LockExpiredError,
+    MutableProxy,
+    State,
+    StateManager,
+    StateManagerMemory,
+    StateManagerRedis,
+    StateProxy,
+    StateUpdate,
+)
+from reflex.utils import format, prerequisites
 from reflex.vars import BaseVar, ComputedVar
 
+from .states import GenState
+
+CI = bool(os.environ.get("CI", False))
+LOCK_EXPIRATION = 2000 if CI else 100
+LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
+
 
 class Object(Base):
     """A test object fixture."""
@@ -704,13 +724,9 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
 
 
 @pytest.mark.asyncio
-async def test_process_event_generator(gen_state):
-    """Test event handlers that generate multiple updates.
-
-    Args:
-        gen_state: A state.
-    """
-    gen_state = gen_state()
+async def test_process_event_generator():
+    """Test event handlers that generate multiple updates."""
+    gen_state = GenState()  # type: ignore
     event = Event(
         token="t",
         name="go",
@@ -1402,6 +1418,396 @@ def test_state_with_invalid_yield():
     )
 
 
+@pytest.fixture(scope="function", params=["in_process", "redis"])
+def state_manager(request) -> Generator[StateManager, None, None]:
+    """Instance of state manager parametrized for redis and in-process.
+
+    Args:
+        request: pytest request object.
+
+    Yields:
+        A state manager instance
+    """
+    state_manager = StateManager.create(state=TestState)
+    if request.param == "redis":
+        if not isinstance(state_manager, StateManagerRedis):
+            pytest.skip("Test requires redis")
+    else:
+        # explicitly NOT using redis
+        state_manager = StateManagerMemory(state=TestState)
+        assert not state_manager._states_locks
+
+    yield state_manager
+
+    if isinstance(state_manager, StateManagerRedis):
+        asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
+
+
+@pytest.mark.asyncio
+async def test_state_manager_modify_state(state_manager: StateManager, token: str):
+    """Test that the state manager can modify a state exclusively.
+
+    Args:
+        state_manager: A state manager instance.
+        token: A token.
+    """
+    async with state_manager.modify_state(token):
+        if isinstance(state_manager, StateManagerRedis):
+            assert await state_manager.redis.get(f"{token}_lock")
+        elif isinstance(state_manager, StateManagerMemory):
+            assert token in state_manager._states_locks
+            assert state_manager._states_locks[token].locked()
+    # lock should be dropped after exiting the context
+    if isinstance(state_manager, StateManagerRedis):
+        assert (await state_manager.redis.get(f"{token}_lock")) is None
+    elif isinstance(state_manager, StateManagerMemory):
+        assert not state_manager._states_locks[token].locked()
+
+        # separate instances should NOT share locks
+        sm2 = StateManagerMemory(state=TestState)
+        assert sm2._state_manager_lock is state_manager._state_manager_lock
+        assert not sm2._states_locks
+        if state_manager._states_locks:
+            assert sm2._states_locks != state_manager._states_locks
+
+
+@pytest.mark.asyncio
+async def test_state_manager_contend(state_manager: StateManager, token: str):
+    """Multiple coroutines attempting to access the same state.
+
+    Args:
+        state_manager: A state manager instance.
+        token: A token.
+    """
+    n_coroutines = 10
+    exp_num1 = 10
+
+    async with state_manager.modify_state(token) as state:
+        state.num1 = 0
+
+    async def _coro():
+        async with state_manager.modify_state(token) as state:
+            await asyncio.sleep(0.01)
+            state.num1 += 1
+
+    tasks = [asyncio.create_task(_coro()) for _ in range(n_coroutines)]
+
+    for f in asyncio.as_completed(tasks):
+        await f
+
+    assert (await state_manager.get_state(token)).num1 == exp_num1
+
+    if isinstance(state_manager, StateManagerRedis):
+        assert (await state_manager.redis.get(f"{token}_lock")) is None
+    elif isinstance(state_manager, StateManagerMemory):
+        assert token in state_manager._states_locks
+        assert not state_manager._states_locks[token].locked()
+
+
+@pytest.fixture(scope="function")
+def state_manager_redis() -> Generator[StateManager, None, None]:
+    """Instance of state manager for redis only.
+
+    Yields:
+        A state manager instance
+    """
+    state_manager = StateManager.create(TestState)
+
+    if not isinstance(state_manager, StateManagerRedis):
+        pytest.skip("Test requires redis")
+
+    yield state_manager
+
+    asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
+
+
+@pytest.mark.asyncio
+async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str):
+    """Test that the state manager lock expires and raises exception exiting context.
+
+    Args:
+        state_manager_redis: A state manager instance.
+        token: A token.
+    """
+    state_manager_redis.lock_expiration = LOCK_EXPIRATION
+
+    async with state_manager_redis.modify_state(token):
+        await asyncio.sleep(0.01)
+
+    with pytest.raises(LockExpiredError):
+        async with state_manager_redis.modify_state(token):
+            await asyncio.sleep(LOCK_EXPIRE_SLEEP)
+
+
+@pytest.mark.asyncio
+async def test_state_manager_lock_expire_contend(
+    state_manager_redis: StateManager, token: str
+):
+    """Test that the state manager lock expires and queued waiters proceed.
+
+    Args:
+        state_manager_redis: A state manager instance.
+        token: A token.
+    """
+    exp_num1 = 4252
+    unexp_num1 = 666
+
+    state_manager_redis.lock_expiration = LOCK_EXPIRATION
+
+    order = []
+
+    async def _coro_blocker():
+        async with state_manager_redis.modify_state(token) as state:
+            order.append("blocker")
+            await asyncio.sleep(LOCK_EXPIRE_SLEEP)
+            state.num1 = unexp_num1
+
+    async def _coro_waiter():
+        while "blocker" not in order:
+            await asyncio.sleep(0.005)
+        async with state_manager_redis.modify_state(token) as state:
+            order.append("waiter")
+            assert state.num1 != unexp_num1
+            state.num1 = exp_num1
+
+    tasks = [
+        asyncio.create_task(_coro_blocker()),
+        asyncio.create_task(_coro_waiter()),
+    ]
+    with pytest.raises(LockExpiredError):
+        await tasks[0]
+    await tasks[1]
+
+    assert order == ["blocker", "waiter"]
+    assert (await state_manager_redis.get_state(token)).num1 == exp_num1
+
+
+@pytest.fixture(scope="function")
+def mock_app(monkeypatch, app: rx.App, state_manager: StateManager) -> rx.App:
+    """Mock app fixture.
+
+    Args:
+        monkeypatch: Pytest monkeypatch object.
+        app: An app.
+        state_manager: A state manager.
+
+    Returns:
+        The app, after mocking out prerequisites.get_app()
+    """
+    app_module = Mock()
+    setattr(app_module, APP_VAR, app)
+    app.state = TestState
+    app.state_manager = state_manager
+    assert app.event_namespace is not None
+    app.event_namespace.emit = AsyncMock()
+    monkeypatch.setattr(prerequisites, "get_app", lambda: app_module)
+    return app
+
+
+@pytest.mark.asyncio
+async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
+    """Test that the state proxy works.
+
+    Args:
+        grandchild_state: A grandchild state.
+        mock_app: An app that will be returned by `get_app()`
+    """
+    child_state = grandchild_state.parent_state
+    assert child_state is not None
+    parent_state = child_state.parent_state
+    assert parent_state is not None
+    if isinstance(mock_app.state_manager, StateManagerMemory):
+        mock_app.state_manager.states[parent_state.get_token()] = parent_state
+
+    sp = StateProxy(grandchild_state)
+    assert sp.__wrapped__ == grandchild_state
+    assert sp._self_substate_path == grandchild_state.get_full_name().split(".")
+    assert sp._self_app is mock_app
+    assert not sp._self_mutable
+    assert sp._self_actx is None
+
+    # cannot use normal contextmanager protocol
+    with pytest.raises(TypeError), sp:
+        pass
+
+    with pytest.raises(ImmutableStateError):
+        # cannot directly modify state proxy outside of async context
+        sp.value2 = 16
+
+    async with sp:
+        assert sp._self_actx is not None
+        assert sp._self_mutable  # proxy is mutable inside context
+        if isinstance(mock_app.state_manager, StateManagerMemory):
+            # For in-process store, only one instance of the state exists
+            assert sp.__wrapped__ is grandchild_state
+        else:
+            # When redis is used, a new+updated instance is assigned to the proxy
+            assert sp.__wrapped__ is not grandchild_state
+        sp.value2 = 42
+    assert not sp._self_mutable  # proxy is not mutable after exiting context
+    assert sp._self_actx is None
+    assert sp.value2 == 42
+
+    # Get the state from the state manager directly and check that the value is updated
+    gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token())
+    if isinstance(mock_app.state_manager, StateManagerMemory):
+        # For in-process store, only one instance of the state exists
+        assert gotten_state is parent_state
+    else:
+        assert gotten_state is not parent_state
+    gotten_grandchild_state = gotten_state.get_substate(sp._self_substate_path)
+    assert gotten_grandchild_state is not None
+    assert gotten_grandchild_state.value2 == 42
+
+    # ensure state update was emitted
+    assert mock_app.event_namespace is not None
+    mock_app.event_namespace.emit.assert_called_once()
+    mcall = mock_app.event_namespace.emit.mock_calls[0]
+    assert mcall.args[0] == str(SocketEvent.EVENT)
+    assert json.loads(mcall.args[1]) == StateUpdate(
+        delta={
+            parent_state.get_full_name(): {
+                "upper": "",
+                "sum": 3.14,
+            },
+            grandchild_state.get_full_name(): {
+                "value2": 42,
+            },
+        }
+    )
+    assert mcall.kwargs["to"] == grandchild_state.get_sid()
+
+
+class BackgroundTaskState(State):
+    """A state with a background task."""
+
+    order: List[str] = []
+    dict_list: Dict[str, List[int]] = {"foo": []}
+
+    @rx.background
+    async def background_task(self):
+        """A background task that updates the state."""
+        async with self:
+            assert not self.order
+            self.order.append("background_task:start")
+
+        assert isinstance(self, StateProxy)
+        with pytest.raises(ImmutableStateError):
+            self.order.append("bad idea")
+
+        with pytest.raises(ImmutableStateError):
+            # Even nested access to mutables raises an exception.
+            self.dict_list["foo"].append(42)
+
+        # wait for some other event to happen
+        while len(self.order) == 1:
+            await asyncio.sleep(0.01)
+            async with self:
+                pass  # update proxy instance
+
+        async with self:
+            self.order.append("background_task:stop")
+
+    @rx.background
+    async def background_task_generator(self):
+        """A background task generator that does nothing.
+
+        Yields:
+            None
+        """
+        yield
+
+    def other(self):
+        """Some other event that updates the state."""
+        self.order.append("other")
+
+    async def bad_chain1(self):
+        """Test that a background task cannot be chained."""
+        await self.background_task()
+
+    async def bad_chain2(self):
+        """Test that a background task generator cannot be chained."""
+        async for _foo in self.background_task_generator():
+            pass
+
+
+@pytest.mark.asyncio
+async def test_background_task_no_block(mock_app: rx.App, token: str):
+    """Test that a background task does not block other events.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+    router_data = {"query": {}}
+    mock_app.state_manager.state = mock_app.state = BackgroundTaskState
+    async for update in rx.app.process(  # type: ignore
+        mock_app,
+        Event(
+            token=token,
+            name=f"{BackgroundTaskState.get_name()}.background_task",
+            router_data=router_data,
+            payload={},
+        ),
+        sid="",
+        headers={},
+        client_ip="",
+    ):
+        # background task returns empty update immediately
+        assert update == StateUpdate()
+    assert len(mock_app.background_tasks) == 1
+
+    # wait for the coroutine to start
+    await asyncio.sleep(0.5 if CI else 0.1)
+    assert len(mock_app.background_tasks) == 1
+
+    # Process another normal event
+    async for update in rx.app.process(  # type: ignore
+        mock_app,
+        Event(
+            token=token,
+            name=f"{BackgroundTaskState.get_name()}.other",
+            router_data=router_data,
+            payload={},
+        ),
+        sid="",
+        headers={},
+        client_ip="",
+    ):
+        # other task returns delta
+        assert update == StateUpdate(
+            delta={
+                BackgroundTaskState.get_name(): {
+                    "order": [
+                        "background_task:start",
+                        "other",
+                    ],
+                }
+            }
+        )
+
+    # Explicit wait for background tasks
+    for task in tuple(mock_app.background_tasks):
+        await task
+    assert not mock_app.background_tasks
+
+    assert (await mock_app.state_manager.get_state(token)).order == [
+        "background_task:start",
+        "other",
+        "background_task:stop",
+    ]
+
+
+@pytest.mark.asyncio
+async def test_background_task_no_chain():
+    """Test that a background task cannot be chained."""
+    bts = BackgroundTaskState()
+    with pytest.raises(RuntimeError):
+        await bts.bad_chain1()
+    with pytest.raises(RuntimeError):
+        await bts.bad_chain2()
+
+
 def test_mutable_list(mutable_state):
     """Test that mutable lists are tracked correctly.