Ver Fonte

rx.background and StateManager.modify_state provides safe exclusive access to state (#1676)

Masen Furer há 1 ano atrás
pai
commit
351611ca25

+ 17 - 0
.github/workflows/integration_app_harness.yml

@@ -15,7 +15,23 @@ permissions:
 
 
 jobs:
 jobs:
   integration-app-harness:
   integration-app-harness:
+    strategy:
+      matrix:
+        state_manager: [ "redis", "memory" ]
     runs-on: ubuntu-latest
     runs-on: ubuntu-latest
+    services:
+      # Label used to access the service container
+      redis:
+        image: ${{ matrix.state_manager == 'redis' && 'redis' || '' }}
+        # Set health checks to wait until redis has started
+        options: >-
+          --health-cmd "redis-cli ping"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          # Maps port 6379 on service container to the host
+          - 6379:6379
     steps:
     steps:
     - uses: actions/checkout@v4
     - uses: actions/checkout@v4
     - uses: ./.github/actions/setup_build_env
     - uses: ./.github/actions/setup_build_env
@@ -27,6 +43,7 @@ jobs:
     - name: Run app harness tests
     - name: Run app harness tests
       env:
       env:
         SCREENSHOT_DIR: /tmp/screenshots
         SCREENSHOT_DIR: /tmp/screenshots
+        REDIS_URL: ${{ matrix.state_manager == 'redis' && 'localhost:6379' || '' }}
       run: |
       run: |
         poetry run pytest integration
         poetry run pytest integration
     - uses: actions/upload-artifact@v3
     - uses: actions/upload-artifact@v3

+ 20 - 0
.github/workflows/unit_tests.yml

@@ -40,6 +40,20 @@ jobs:
           - os: windows-latest
           - os: windows-latest
             python-version: "3.8.10"
             python-version: "3.8.10"
     runs-on: ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
+    # Service containers to run with `runner-job`
+    services:
+      # Label used to access the service container
+      redis:
+        image: ${{ matrix.os == 'ubuntu-latest' && 'redis' || '' }}
+        # Set health checks to wait until redis has started
+        options: >-
+          --health-cmd "redis-cli ping"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          # Maps port 6379 on service container to the host
+          - 6379:6379
     steps:
     steps:
     - uses: actions/checkout@v4
     - uses: actions/checkout@v4
     - uses: ./.github/actions/setup_build_env
     - uses: ./.github/actions/setup_build_env
@@ -51,4 +65,10 @@ jobs:
       run: |
       run: |
         export PYTHONUNBUFFERED=1
         export PYTHONUNBUFFERED=1
         poetry run pytest tests --cov --no-cov-on-fail --cov-report=
         poetry run pytest tests --cov --no-cov-on-fail --cov-report=
+    - name: Run unit tests w/ redis
+      if: ${{ matrix.os == 'ubuntu-latest' }}
+      run: |
+        export PYTHONUNBUFFERED=1
+        export REDIS_URL=localhost:6379
+        poetry run pytest tests --cov --no-cov-on-fail --cov-report=
     - run: poetry run coverage html
     - run: poetry run coverage html

+ 6 - 6
.pre-commit-config.yaml

@@ -1,4 +1,10 @@
 repos:
 repos:
+  - repo: https://github.com/psf/black
+    rev: 22.10.0
+    hooks:
+    - id: black
+      args: [integration, reflex, tests]
+
   - repo: https://github.com/charliermarsh/ruff-pre-commit
   - repo: https://github.com/charliermarsh/ruff-pre-commit
     rev: v0.0.244
     rev: v0.0.244
     hooks:
     hooks:
@@ -17,9 +23,3 @@ repos:
     hooks:
     hooks:
     - id: darglint
     - id: darglint
       exclude: '^reflex/reflex.py'
       exclude: '^reflex/reflex.py'
-
-  - repo: https://github.com/psf/black
-    rev: 22.10.0
-    hooks:
-    - id: black
-      args: [integration, reflex, tests]

+ 214 - 0
integration/test_background_task.py

@@ -0,0 +1,214 @@
+"""Test @rx.background task functionality."""
+
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
+
+
+def BackgroundTask():
+    """Test that background tasks work as expected."""
+    import asyncio
+
+    import reflex as rx
+
+    class State(rx.State):
+        counter: int = 0
+        _task_id: int = 0
+        iterations: int = 10
+
+        @rx.background
+        async def handle_event(self):
+            async with self:
+                self._task_id += 1
+            for _ix in range(int(self.iterations)):
+                async with self:
+                    self.counter += 1
+                await asyncio.sleep(0.005)
+
+        @rx.background
+        async def handle_event_yield_only(self):
+            async with self:
+                self._task_id += 1
+            for ix in range(int(self.iterations)):
+                if ix % 2 == 0:
+                    yield State.increment_arbitrary(1)  # type: ignore
+                else:
+                    yield State.increment()  # type: ignore
+                await asyncio.sleep(0.005)
+
+        def increment(self):
+            self.counter += 1
+
+        @rx.background
+        async def increment_arbitrary(self, amount: int):
+            async with self:
+                self.counter += int(amount)
+
+        def reset_counter(self):
+            self.counter = 0
+
+        async def blocking_pause(self):
+            await asyncio.sleep(0.02)
+
+        @rx.background
+        async def non_blocking_pause(self):
+            await asyncio.sleep(0.02)
+
+        @rx.cached_var
+        def token(self) -> str:
+            return self.get_token()
+
+    def index() -> rx.Component:
+        return rx.vstack(
+            rx.input(id="token", value=State.token, is_read_only=True),
+            rx.heading(State.counter, id="counter"),
+            rx.input(
+                id="iterations",
+                placeholder="Iterations",
+                value=State.iterations.to_string(),  # type: ignore
+                on_change=State.set_iterations,  # type: ignore
+            ),
+            rx.button(
+                "Delayed Increment",
+                on_click=State.handle_event,
+                id="delayed-increment",
+            ),
+            rx.button(
+                "Yield Increment",
+                on_click=State.handle_event_yield_only,
+                id="yield-increment",
+            ),
+            rx.button("Increment 1", on_click=State.increment, id="increment"),
+            rx.button(
+                "Blocking Pause",
+                on_click=State.blocking_pause,
+                id="blocking-pause",
+            ),
+            rx.button(
+                "Non-Blocking Pause",
+                on_click=State.non_blocking_pause,
+                id="non-blocking-pause",
+            ),
+            rx.button("Reset", on_click=State.reset_counter, id="reset"),
+        )
+
+    app = rx.App(state=State)
+    app.add_page(index)
+    app.compile()
+
+
+@pytest.fixture(scope="session")
+def background_task(
+    tmp_path_factory,
+) -> Generator[AppHarness, None, None]:
+    """Start BackgroundTask app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path_factory.mktemp(f"background_task"),
+        app_source=BackgroundTask,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.fixture
+def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
+    """Get an instance of the browser open to the background_task app.
+
+    Args:
+        background_task: harness for BackgroundTask app
+
+    Yields:
+        WebDriver instance.
+    """
+    assert background_task.app_instance is not None, "app is not running"
+    driver = background_task.frontend()
+    try:
+        yield driver
+    finally:
+        driver.quit()
+
+
+@pytest.fixture()
+def token(background_task: AppHarness, driver: WebDriver) -> str:
+    """Get a function that returns the active token.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+
+    Returns:
+        The token for the connected client
+    """
+    assert background_task.app_instance is not None
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
+    assert token is not None
+
+    return token
+
+
+def test_background_task(
+    background_task: AppHarness,
+    driver: WebDriver,
+    token: str,
+):
+    """Test that background tasks work as expected.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert background_task.app_instance is not None
+
+    # get a reference to all buttons
+    delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
+    yield_increment_button = driver.find_element(By.ID, "yield-increment")
+    increment_button = driver.find_element(By.ID, "increment")
+    blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
+    non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
+    driver.find_element(By.ID, "reset")
+
+    # get a reference to the counter
+    counter = driver.find_element(By.ID, "counter")
+
+    # get a reference to the iterations input
+    iterations_input = driver.find_element(By.ID, "iterations")
+
+    # kick off background tasks
+    iterations_input.clear()
+    iterations_input.send_keys("50")
+    delayed_increment_button.click()
+    blocking_pause_button.click()
+    delayed_increment_button.click()
+    for _ in range(10):
+        increment_button.click()
+    blocking_pause_button.click()
+    delayed_increment_button.click()
+    delayed_increment_button.click()
+    yield_increment_button.click()
+    non_blocking_pause_button.click()
+    yield_increment_button.click()
+    blocking_pause_button.click()
+    yield_increment_button.click()
+    for _ in range(10):
+        increment_button.click()
+    yield_increment_button.click()
+    blocking_pause_button.click()
+    assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
+    # all tasks should have exited and cleaned up
+    assert background_task._poll_for(
+        lambda: not background_task.app_instance.background_tasks  # type: ignore
+    )

+ 23 - 14
integration/test_client_storage.py

@@ -133,7 +133,6 @@ def driver(client_side: AppHarness) -> Generator[WebDriver, None, None]:
     assert client_side.app_instance is not None, "app is not running"
     assert client_side.app_instance is not None, "app is not running"
     driver = client_side.frontend()
     driver = client_side.frontend()
     try:
     try:
-        assert client_side.poll_for_clients()
         yield driver
         yield driver
     finally:
     finally:
         driver.quit()
         driver.quit()
@@ -168,7 +167,20 @@ def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
     driver.delete_all_cookies()
     driver.delete_all_cookies()
 
 
 
 
-def test_client_side_state(
+def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]:
+    """Get a map of cookie names to cookie info.
+
+    Args:
+        driver: WebDriver instance.
+
+    Returns:
+        A map of cookie names to cookie info.
+    """
+    return {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
+
+
+@pytest.mark.asyncio
+async def test_client_side_state(
     client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
     client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
 ):
 ):
     """Test client side state.
     """Test client side state.
@@ -187,8 +199,6 @@ def test_client_side_state(
     token = client_side.poll_for_value(token_input)
     token = client_side.poll_for_value(token_input)
     assert token is not None
     assert token is not None
 
 
-    backend_state = client_side.app_instance.state_manager.states[token]
-
     # get a reference to the cookie manipulation form
     # get a reference to the cookie manipulation form
     state_var_input = driver.find_element(By.ID, "state_var")
     state_var_input = driver.find_element(By.ID, "state_var")
     input_value_input = driver.find_element(By.ID, "input_value")
     input_value_input = driver.find_element(By.ID, "input_value")
@@ -274,7 +284,7 @@ def test_client_side_state(
     input_value_input.send_keys("l1s value")
     input_value_input.send_keys("l1s value")
     set_sub_sub_state_button.click()
     set_sub_sub_state_button.click()
 
 
-    cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
+    cookies = cookie_info_map(driver)
     assert cookies.pop("client_side_state.client_side_sub_state.c1") == {
     assert cookies.pop("client_side_state.client_side_sub_state.c1") == {
         "domain": "localhost",
         "domain": "localhost",
         "httpOnly": False,
         "httpOnly": False,
@@ -338,8 +348,10 @@ def test_client_side_state(
     state_var_input.send_keys("c3")
     state_var_input.send_keys("c3")
     input_value_input.send_keys("c3 value")
     input_value_input.send_keys("c3 value")
     set_sub_state_button.click()
     set_sub_state_button.click()
-    cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
-    c3_cookie = cookies["client_side_state.client_side_sub_state.c3"]
+    AppHarness._poll_for(
+        lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
+    )
+    c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
     assert c3_cookie.pop("expiry") is not None
     assert c3_cookie.pop("expiry") is not None
     assert c3_cookie == {
     assert c3_cookie == {
         "domain": "localhost",
         "domain": "localhost",
@@ -351,9 +363,7 @@ def test_client_side_state(
         "value": "c3%20value",
         "value": "c3%20value",
     }
     }
     time.sleep(2)  # wait for c3 to expire
     time.sleep(2)  # wait for c3 to expire
-    assert "client_side_state.client_side_sub_state.c3" not in {
-        cookie_info["name"] for cookie_info in driver.get_cookies()
-    }
+    assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
 
 
     local_storage_items = local_storage.items()
     local_storage_items = local_storage.items()
     local_storage_items.pop("chakra-ui-color-mode", None)
     local_storage_items.pop("chakra-ui-color-mode", None)
@@ -426,7 +436,8 @@ def test_client_side_state(
     assert l1s.text == "l1s value"
     assert l1s.text == "l1s value"
 
 
     # reset the backend state to force refresh from client storage
     # reset the backend state to force refresh from client storage
-    backend_state.reset()
+    async with client_side.modify_state(token) as state:
+        state.reset()
     driver.refresh()
     driver.refresh()
 
 
     # wait for the backend connection to send the token (again)
     # wait for the backend connection to send the token (again)
@@ -465,9 +476,7 @@ def test_client_side_state(
     assert l1s.text == "l1s value"
     assert l1s.text == "l1s value"
 
 
     # make sure c5 cookie shows up on the `/foo` route
     # make sure c5 cookie shows up on the `/foo` route
-    cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
-
-    assert cookies["client_side_state.client_side_sub_state.c5"] == {
+    assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
         "domain": "localhost",
         "domain": "localhost",
         "httpOnly": False,
         "httpOnly": False,
         "name": "client_side_state.client_side_sub_state.c5",
         "name": "client_side_state.client_side_sub_state.c5",

+ 37 - 32
integration/test_dynamic_routes.py

@@ -1,11 +1,10 @@
 """Integration tests for dynamic route page behavior."""
 """Integration tests for dynamic route page behavior."""
-from typing import Callable, Generator, Type
+from typing import Callable, Coroutine, Generator, Type
 from urllib.parse import urlsplit
 from urllib.parse import urlsplit
 
 
 import pytest
 import pytest
 from selenium.webdriver.common.by import By
 from selenium.webdriver.common.by import By
 
 
-from reflex import State
 from reflex.testing import AppHarness, AppHarnessProd, WebDriver
 from reflex.testing import AppHarness, AppHarnessProd, WebDriver
 
 
 from .utils import poll_for_navigation
 from .utils import poll_for_navigation
@@ -100,22 +99,21 @@ def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]:
     assert dynamic_route.app_instance is not None, "app is not running"
     assert dynamic_route.app_instance is not None, "app is not running"
     driver = dynamic_route.frontend()
     driver = dynamic_route.frontend()
     try:
     try:
-        assert dynamic_route.poll_for_clients()
         yield driver
         yield driver
     finally:
     finally:
         driver.quit()
         driver.quit()
 
 
 
 
 @pytest.fixture()
 @pytest.fixture()
-def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
-    """Get the backend state.
+def token(dynamic_route: AppHarness, driver: WebDriver) -> str:
+    """Get the token associated with backend state.
 
 
     Args:
     Args:
         dynamic_route: harness for DynamicRoute app.
         dynamic_route: harness for DynamicRoute app.
         driver: WebDriver instance.
         driver: WebDriver instance.
 
 
     Returns:
     Returns:
-        The backend state associated with the token visible in the driver browser.
+        The token visible in the driver browser.
     """
     """
     assert dynamic_route.app_instance is not None
     assert dynamic_route.app_instance is not None
     token_input = driver.find_element(By.ID, "token")
     token_input = driver.find_element(By.ID, "token")
@@ -125,43 +123,49 @@ def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
     token = dynamic_route.poll_for_value(token_input)
     token = dynamic_route.poll_for_value(token_input)
     assert token is not None
     assert token is not None
 
 
-    # look up the backend state from the state manager
-    return dynamic_route.app_instance.state_manager.states[token]
+    return token
 
 
 
 
 @pytest.fixture()
 @pytest.fixture()
 def poll_for_order(
 def poll_for_order(
-    dynamic_route: AppHarness, backend_state: State
-) -> Callable[[list[str]], None]:
+    dynamic_route: AppHarness, token: str
+) -> Callable[[list[str]], Coroutine[None, None, None]]:
     """Poll for the order list to match the expected order.
     """Poll for the order list to match the expected order.
 
 
     Args:
     Args:
         dynamic_route: harness for DynamicRoute app.
         dynamic_route: harness for DynamicRoute app.
-        backend_state: The backend state associated with the token visible in the driver browser.
+        token: The token visible in the driver browser.
 
 
     Returns:
     Returns:
-        A function that polls for the order list to match the expected order.
+        An async function that polls for the order list to match the expected order.
     """
     """
 
 
-    def _poll_for_order(exp_order: list[str]):
-        dynamic_route._poll_for(lambda: backend_state.order == exp_order)
-        assert backend_state.order == exp_order
+    async def _poll_for_order(exp_order: list[str]):
+        async def _backend_state():
+            return await dynamic_route.get_state(token)
+
+        async def _check():
+            return (await _backend_state()).order == exp_order
+
+        await AppHarness._poll_for_async(_check)
+        assert (await _backend_state()).order == exp_order
 
 
     return _poll_for_order
     return _poll_for_order
 
 
 
 
-def test_on_load_navigate(
+@pytest.mark.asyncio
+async def test_on_load_navigate(
     dynamic_route: AppHarness,
     dynamic_route: AppHarness,
     driver: WebDriver,
     driver: WebDriver,
-    backend_state: State,
-    poll_for_order: Callable[[list[str]], None],
+    token: str,
+    poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
 ):
 ):
     """Click links to navigate between dynamic pages with on_load event.
     """Click links to navigate between dynamic pages with on_load event.
 
 
     Args:
     Args:
         dynamic_route: harness for DynamicRoute app.
         dynamic_route: harness for DynamicRoute app.
         driver: WebDriver instance.
         driver: WebDriver instance.
-        backend_state: The backend state associated with the token visible in the driver browser.
+        token: The token visible in the driver browser.
         poll_for_order: function that polls for the order list to match the expected order.
         poll_for_order: function that polls for the order list to match the expected order.
     """
     """
     assert dynamic_route.app_instance is not None
     assert dynamic_route.app_instance is not None
@@ -184,7 +188,7 @@ def test_on_load_navigate(
         assert page_id_input
         assert page_id_input
 
 
         assert dynamic_route.poll_for_value(page_id_input) == str(ix)
         assert dynamic_route.poll_for_value(page_id_input) == str(ix)
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
 
     # manually load the next page to trigger client side routing in prod mode
     # manually load the next page to trigger client side routing in prod mode
     if is_prod:
     if is_prod:
@@ -192,14 +196,14 @@ def test_on_load_navigate(
     exp_order += ["/page/[page_id]-10"]
     exp_order += ["/page/[page_id]-10"]
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         driver.get(f"{dynamic_route.frontend_url}/page/10/")
         driver.get(f"{dynamic_route.frontend_url}/page/10/")
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
 
     # make sure internal nav still hydrates after redirect
     # make sure internal nav still hydrates after redirect
     exp_order += ["/page/[page_id]-11"]
     exp_order += ["/page/[page_id]-11"]
     link = driver.find_element(By.ID, "link_page_next")
     link = driver.find_element(By.ID, "link_page_next")
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         link.click()
         link.click()
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
 
     # load same page with a query param and make sure it passes through
     # load same page with a query param and make sure it passes through
     if is_prod:
     if is_prod:
@@ -207,14 +211,14 @@ def test_on_load_navigate(
     exp_order += ["/page/[page_id]-11"]
     exp_order += ["/page/[page_id]-11"]
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         driver.get(f"{driver.current_url}?foo=bar")
         driver.get(f"{driver.current_url}?foo=bar")
-    poll_for_order(exp_order)
-    assert backend_state.get_query_params()["foo"] == "bar"
+    await poll_for_order(exp_order)
+    assert (await dynamic_route.get_state(token)).get_query_params()["foo"] == "bar"
 
 
     # hit a 404 and ensure we still hydrate
     # hit a 404 and ensure we still hydrate
     exp_order += ["/404-no page id"]
     exp_order += ["/404-no page id"]
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         driver.get(f"{dynamic_route.frontend_url}/missing")
         driver.get(f"{dynamic_route.frontend_url}/missing")
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
 
     # browser nav should still trigger hydration
     # browser nav should still trigger hydration
     if is_prod:
     if is_prod:
@@ -222,14 +226,14 @@ def test_on_load_navigate(
     exp_order += ["/page/[page_id]-11"]
     exp_order += ["/page/[page_id]-11"]
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         driver.back()
         driver.back()
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
 
     # next/link to a 404 and ensure we still hydrate
     # next/link to a 404 and ensure we still hydrate
     exp_order += ["/404-no page id"]
     exp_order += ["/404-no page id"]
     link = driver.find_element(By.ID, "link_missing")
     link = driver.find_element(By.ID, "link_missing")
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         link.click()
         link.click()
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
 
 
     # hit a page that redirects back to dynamic page
     # hit a page that redirects back to dynamic page
     if is_prod:
     if is_prod:
@@ -237,15 +241,16 @@ def test_on_load_navigate(
     exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"]
     exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"]
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
         driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
-    poll_for_order(exp_order)
+    await poll_for_order(exp_order)
     # should have redirected back to page 0
     # should have redirected back to page 0
     assert urlsplit(driver.current_url).path == "/page/0/"
     assert urlsplit(driver.current_url).path == "/page/0/"
 
 
 
 
-def test_on_load_navigate_non_dynamic(
+@pytest.mark.asyncio
+async def test_on_load_navigate_non_dynamic(
     dynamic_route: AppHarness,
     dynamic_route: AppHarness,
     driver: WebDriver,
     driver: WebDriver,
-    poll_for_order: Callable[[list[str]], None],
+    poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
 ):
 ):
     """Click links to navigate between static pages with on_load event.
     """Click links to navigate between static pages with on_load event.
 
 
@@ -261,7 +266,7 @@ def test_on_load_navigate_non_dynamic(
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         link.click()
         link.click()
     assert urlsplit(driver.current_url).path == "/static/x/"
     assert urlsplit(driver.current_url).path == "/static/x/"
-    poll_for_order(["/static/x-no page id"])
+    await poll_for_order(["/static/x-no page id"])
 
 
     # go back to the index and navigate back to the static route
     # go back to the index and navigate back to the static route
     link = driver.find_element(By.ID, "link_index")
     link = driver.find_element(By.ID, "link_index")
@@ -273,4 +278,4 @@ def test_on_load_navigate_non_dynamic(
     with poll_for_navigation(driver):
     with poll_for_navigation(driver):
         link.click()
         link.click()
     assert urlsplit(driver.current_url).path == "/static/x/"
     assert urlsplit(driver.current_url).path == "/static/x/"
-    poll_for_order(["/static/x-no page id", "/static/x-no page id"])
+    await poll_for_order(["/static/x-no page id", "/static/x-no page id"])

+ 110 - 22
integration/test_event_chain.py

@@ -1,18 +1,20 @@
 """Ensure that Event Chains are properly queued and handled between frontend and backend."""
 """Ensure that Event Chains are properly queued and handled between frontend and backend."""
 
 
-import time
 from typing import Generator
 from typing import Generator
 
 
 import pytest
 import pytest
 from selenium.webdriver.common.by import By
 from selenium.webdriver.common.by import By
 
 
-from reflex.testing import AppHarness
+from reflex.testing import AppHarness, WebDriver
 
 
 MANY_EVENTS = 50
 MANY_EVENTS = 50
 
 
 
 
 def EventChain():
 def EventChain():
     """App with chained event handlers."""
     """App with chained event handlers."""
+    import asyncio
+    import time
+
     import reflex as rx
     import reflex as rx
 
 
     # repeated here since the outer global isn't exported into the App module
     # repeated here since the outer global isn't exported into the App module
@@ -20,6 +22,7 @@ def EventChain():
 
 
     class State(rx.State):
     class State(rx.State):
         event_order: list[str] = []
         event_order: list[str] = []
+        interim_value: str = ""
 
 
         @rx.var
         @rx.var
         def token(self) -> str:
         def token(self) -> str:
@@ -111,12 +114,25 @@ def EventChain():
             self.event_order.append("click_return_dict_type")
             self.event_order.append("click_return_dict_type")
             return State.event_arg_repr_type({"a": 1})  # type: ignore
             return State.event_arg_repr_type({"a": 1})  # type: ignore
 
 
+        async def click_yield_interim_value_async(self):
+            self.interim_value = "interim"
+            yield
+            await asyncio.sleep(0.5)
+            self.interim_value = "final"
+
+        def click_yield_interim_value(self):
+            self.interim_value = "interim"
+            yield
+            time.sleep(0.5)
+            self.interim_value = "final"
+
     app = rx.App(state=State)
     app = rx.App(state=State)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():
         return rx.fragment(
         return rx.fragment(
             rx.input(value=State.token, readonly=True, id="token"),
             rx.input(value=State.token, readonly=True, id="token"),
+            rx.input(value=State.interim_value, readonly=True, id="interim_value"),
             rx.button(
             rx.button(
                 "Return Event",
                 "Return Event",
                 id="return_event",
                 id="return_event",
@@ -172,6 +188,16 @@ def EventChain():
                 id="return_dict_type",
                 id="return_dict_type",
                 on_click=State.click_return_dict_type,
                 on_click=State.click_return_dict_type,
             ),
             ),
+            rx.button(
+                "Click Yield Interim Value (Async)",
+                id="click_yield_interim_value_async",
+                on_click=State.click_yield_interim_value_async,
+            ),
+            rx.button(
+                "Click Yield Interim Value",
+                id="click_yield_interim_value",
+                on_click=State.click_yield_interim_value,
+            ),
         )
         )
 
 
     def on_load_return_chain():
     def on_load_return_chain():
@@ -237,7 +263,7 @@ def event_chain(tmp_path_factory) -> Generator[AppHarness, None, None]:
 
 
 
 
 @pytest.fixture
 @pytest.fixture
-def driver(event_chain: AppHarness):
+def driver(event_chain: AppHarness) -> Generator[WebDriver, None, None]:
     """Get an instance of the browser open to the event_chain app.
     """Get an instance of the browser open to the event_chain app.
 
 
     Args:
     Args:
@@ -249,7 +275,6 @@ def driver(event_chain: AppHarness):
     assert event_chain.app_instance is not None, "app is not running"
     assert event_chain.app_instance is not None, "app is not running"
     driver = event_chain.frontend()
     driver = event_chain.frontend()
     try:
     try:
-        assert event_chain.poll_for_clients()
         yield driver
         yield driver
     finally:
     finally:
         driver.quit()
         driver.quit()
@@ -335,7 +360,13 @@ def driver(event_chain: AppHarness):
         ),
         ),
     ],
     ],
 )
 )
-def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
+@pytest.mark.asyncio
+async def test_event_chain_click(
+    event_chain: AppHarness,
+    driver: WebDriver,
+    button_id: str,
+    exp_event_order: list[str],
+):
     """Click the button, assert that the events are handled in the correct order.
     """Click the button, assert that the events are handled in the correct order.
 
 
     Args:
     Args:
@@ -350,17 +381,18 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
     assert btn
     assert btn
 
 
     token = event_chain.poll_for_value(token_input)
     token = event_chain.poll_for_value(token_input)
+    assert token is not None
 
 
     btn.click()
     btn.click()
-    if "redirect" in button_id:
-        # wait a bit longer if we're redirecting
-        time.sleep(1)
-    if "many_events" in button_id:
-        # wait a bit longer if we have loads of events
-        time.sleep(1)
-    time.sleep(0.5)
-    backend_state = event_chain.app_instance.state_manager.states[token]
-    assert backend_state.event_order == exp_event_order
+
+    async def _has_all_events():
+        return len((await event_chain.get_state(token)).event_order) == len(
+            exp_event_order
+        )
+
+    await AppHarness._poll_for_async(_has_all_events)
+    event_order = (await event_chain.get_state(token)).event_order
+    assert event_order == exp_event_order
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -386,7 +418,13 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
         ),
         ),
     ],
     ],
 )
 )
-def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
+@pytest.mark.asyncio
+async def test_event_chain_on_load(
+    event_chain: AppHarness,
+    driver: WebDriver,
+    uri: str,
+    exp_event_order: list[str],
+):
     """Load the URI, assert that the events are handled in the correct order.
     """Load the URI, assert that the events are handled in the correct order.
 
 
     Args:
     Args:
@@ -395,16 +433,23 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
         uri: the page to load
         uri: the page to load
         exp_event_order: the expected events recorded in the State
         exp_event_order: the expected events recorded in the State
     """
     """
+    assert event_chain.frontend_url is not None
     driver.get(event_chain.frontend_url + uri)
     driver.get(event_chain.frontend_url + uri)
     token_input = driver.find_element(By.ID, "token")
     token_input = driver.find_element(By.ID, "token")
     assert token_input
     assert token_input
 
 
     token = event_chain.poll_for_value(token_input)
     token = event_chain.poll_for_value(token_input)
+    assert token is not None
 
 
-    time.sleep(0.5)
-    backend_state = event_chain.app_instance.state_manager.states[token]
-    assert backend_state.is_hydrated is True
+    async def _has_all_events():
+        return len((await event_chain.get_state(token)).event_order) == len(
+            exp_event_order
+        )
+
+    await AppHarness._poll_for_async(_has_all_events)
+    backend_state = await event_chain.get_state(token)
     assert backend_state.event_order == exp_event_order
     assert backend_state.event_order == exp_event_order
+    assert backend_state.is_hydrated is True
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -444,7 +489,13 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
         ),
         ),
     ],
     ],
 )
 )
-def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
+@pytest.mark.asyncio
+async def test_event_chain_on_mount(
+    event_chain: AppHarness,
+    driver: WebDriver,
+    uri: str,
+    exp_event_order: list[str],
+):
     """Load the URI, assert that the events are handled in the correct order.
     """Load the URI, assert that the events are handled in the correct order.
 
 
     These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
     These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
@@ -458,16 +509,53 @@ def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
         uri: the page to load
         uri: the page to load
         exp_event_order: the expected events recorded in the State
         exp_event_order: the expected events recorded in the State
     """
     """
+    assert event_chain.frontend_url is not None
     driver.get(event_chain.frontend_url + uri)
     driver.get(event_chain.frontend_url + uri)
     token_input = driver.find_element(By.ID, "token")
     token_input = driver.find_element(By.ID, "token")
     assert token_input
     assert token_input
 
 
     token = event_chain.poll_for_value(token_input)
     token = event_chain.poll_for_value(token_input)
+    assert token is not None
 
 
     unmount_button = driver.find_element(By.ID, "unmount")
     unmount_button = driver.find_element(By.ID, "unmount")
     assert unmount_button
     assert unmount_button
     unmount_button.click()
     unmount_button.click()
 
 
-    time.sleep(1)
-    backend_state = event_chain.app_instance.state_manager.states[token]
-    assert backend_state.event_order == exp_event_order
+    async def _has_all_events():
+        return len((await event_chain.get_state(token)).event_order) == len(
+            exp_event_order
+        )
+
+    await AppHarness._poll_for_async(_has_all_events)
+    event_order = (await event_chain.get_state(token)).event_order
+    assert event_order == exp_event_order
+
+
+@pytest.mark.parametrize(
+    ("button_id",),
+    [
+        ("click_yield_interim_value_async",),
+        ("click_yield_interim_value",),
+    ],
+)
+def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_id: str):
+    """Click the button, assert that the interim value is set, then final value is set.
+
+    Args:
+        event_chain: AppHarness for the event_chain app
+        driver: selenium WebDriver open to the app
+        button_id: the ID of the button to click
+    """
+    token_input = driver.find_element(By.ID, "token")
+    interim_value_input = driver.find_element(By.ID, "interim_value")
+    assert event_chain.poll_for_value(token_input)
+
+    btn = driver.find_element(By.ID, button_id)
+    btn.click()
+    assert (
+        event_chain.poll_for_value(interim_value_input, exp_not_equal="") == "interim"
+    )
+    assert (
+        event_chain.poll_for_value(interim_value_input, exp_not_equal="interim")
+        == "final"
+    )

+ 32 - 18
integration/test_form_submit.py

@@ -19,11 +19,16 @@ def FormSubmit():
         def form_submit(self, form_data: dict):
         def form_submit(self, form_data: dict):
             self.form_data = form_data
             self.form_data = form_data
 
 
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
     app = rx.App(state=FormState)
     app = rx.App(state=FormState)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():
         return rx.vstack(
         return rx.vstack(
+            rx.input(value=FormState.token, is_read_only=True, id="token"),
             rx.form(
             rx.form(
                 rx.vstack(
                 rx.vstack(
                     rx.input(id="name_input"),
                     rx.input(id="name_input"),
@@ -82,13 +87,13 @@ def driver(form_submit: AppHarness):
     """
     """
     driver = form_submit.frontend()
     driver = form_submit.frontend()
     try:
     try:
-        assert form_submit.poll_for_clients()
         yield driver
         yield driver
     finally:
     finally:
         driver.quit()
         driver.quit()
 
 
 
 
-def test_submit(driver, form_submit: AppHarness):
+@pytest.mark.asyncio
+async def test_submit(driver, form_submit: AppHarness):
     """Fill a form with various different output, submit it to backend and verify
     """Fill a form with various different output, submit it to backend and verify
     the output.
     the output.
 
 
@@ -97,7 +102,14 @@ def test_submit(driver, form_submit: AppHarness):
         form_submit: harness for FormSubmit app
         form_submit: harness for FormSubmit app
     """
     """
     assert form_submit.app_instance is not None, "app is not running"
     assert form_submit.app_instance is not None, "app is not running"
-    _, backend_state = list(form_submit.app_instance.state_manager.states.items())[0]
+
+    # get a reference to the connected client
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = form_submit.poll_for_value(token_input)
+    assert token
 
 
     name_input = driver.find_element(By.ID, "name_input")
     name_input = driver.find_element(By.ID, "name_input")
     name_input.send_keys("foo")
     name_input.send_keys("foo")
@@ -132,19 +144,21 @@ def test_submit(driver, form_submit: AppHarness):
     submit_input = driver.find_element(By.CLASS_NAME, "chakra-button")
     submit_input = driver.find_element(By.CLASS_NAME, "chakra-button")
     submit_input.click()
     submit_input.click()
 
 
+    async def get_form_data():
+        return (await form_submit.get_state(token)).form_data
+
     # wait for the form data to arrive at the backend
     # wait for the form data to arrive at the backend
-    AppHarness._poll_for(
-        lambda: backend_state.form_data != {},
-    )
-
-    assert backend_state.form_data["name_input"] == "foo"
-    assert backend_state.form_data["pin_input"] == pin_values
-    assert backend_state.form_data["number_input"] == "-3"
-    assert backend_state.form_data["bool_input"] is True
-    assert backend_state.form_data["bool_input2"] is True
-    assert backend_state.form_data["slider_input"] == "50"
-    assert backend_state.form_data["range_input"] == ["25", "75"]
-    assert backend_state.form_data["radio_input"] == "option2"
-    assert backend_state.form_data["select_input"] == "option1"
-    assert backend_state.form_data["text_area_input"] == "Some\nText"
-    assert backend_state.form_data["debounce_input"] == "bar baz"
+    form_data = await AppHarness._poll_for_async(get_form_data)
+    assert isinstance(form_data, dict)
+
+    assert form_data["name_input"] == "foo"
+    assert form_data["pin_input"] == pin_values
+    assert form_data["number_input"] == "-3"
+    assert form_data["bool_input"] is True
+    assert form_data["bool_input2"] is True
+    assert form_data["slider_input"] == "50"
+    assert form_data["range_input"] == ["25", "75"]
+    assert form_data["radio_input"] == "option2"
+    assert form_data["select_input"] == "option1"
+    assert form_data["text_area_input"] == "Some\nText"
+    assert form_data["debounce_input"] == "bar baz"

+ 19 - 11
integration/test_input.py

@@ -16,11 +16,16 @@ def FullyControlledInput():
     class State(rx.State):
     class State(rx.State):
         text: str = "initial"
         text: str = "initial"
 
 
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
     app = rx.App(state=State)
     app = rx.App(state=State)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():
         return rx.fragment(
         return rx.fragment(
+            rx.input(value=State.token, is_read_only=True, id="token"),
             rx.input(
             rx.input(
                 id="debounce_input_input",
                 id="debounce_input_input",
                 on_change=State.set_text,  # type: ignore
                 on_change=State.set_text,  # type: ignore
@@ -62,10 +67,12 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     driver = fully_controlled_input.frontend()
     driver = fully_controlled_input.frontend()
 
 
     # get a reference to the connected client
     # get a reference to the connected client
-    assert len(fully_controlled_input.poll_for_clients()) == 1
-    token, backend_state = list(
-        fully_controlled_input.app_instance.state_manager.states.items()
-    )[0]
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = fully_controlled_input.poll_for_value(token_input)
+    assert token
 
 
     # find the input and wait for it to have the initial state value
     # find the input and wait for it to have the initial state value
     debounce_input = driver.find_element(By.ID, "debounce_input_input")
     debounce_input = driver.find_element(By.ID, "debounce_input_input")
@@ -80,14 +87,13 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("foo")
     debounce_input.send_keys("foo")
     time.sleep(0.5)
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "ifoonitial"
     assert debounce_input.get_attribute("value") == "ifoonitial"
-    assert backend_state.text == "ifoonitial"
+    assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
     assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
     assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
 
 
     # clear the input on the backend
     # clear the input on the backend
-    backend_state.text = ""
-    fully_controlled_input.app_instance.state_manager.set_state(token, backend_state)
-    await fully_controlled_input.emit_state_updates()
-    assert backend_state.text == ""
+    async with fully_controlled_input.modify_state(token) as state:
+        state.text = ""
+    assert (await fully_controlled_input.get_state(token)).text == ""
     assert (
     assert (
         fully_controlled_input.poll_for_value(
         fully_controlled_input.poll_for_value(
             debounce_input, exp_not_equal="ifoonitial"
             debounce_input, exp_not_equal="ifoonitial"
@@ -99,7 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("getting testing done")
     debounce_input.send_keys("getting testing done")
     time.sleep(0.5)
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "getting testing done"
     assert debounce_input.get_attribute("value") == "getting testing done"
-    assert backend_state.text == "getting testing done"
+    assert (
+        await fully_controlled_input.get_state(token)
+    ).text == "getting testing done"
     assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
     assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
 
 
     # type into the on_change input
     # type into the on_change input
@@ -107,7 +115,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     time.sleep(0.5)
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "overwrite the state"
     assert debounce_input.get_attribute("value") == "overwrite the state"
     assert on_change_input.get_attribute("value") == "overwrite the state"
     assert on_change_input.get_attribute("value") == "overwrite the state"
-    assert backend_state.text == "overwrite the state"
+    assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
     assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
     assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
 
 
     clear_button.click()
     clear_button.click()

+ 11 - 1
integration/test_server_side_event.py

@@ -33,11 +33,16 @@ def ServerSideEvent():
         def set_value_return_c(self):
         def set_value_return_c(self):
             return rx.set_value("c", "")
             return rx.set_value("c", "")
 
 
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
     app = rx.App(state=SSState)
     app = rx.App(state=SSState)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():
         return rx.fragment(
         return rx.fragment(
+            rx.input(id="token", value=SSState.token, is_read_only=True),
             rx.input(default_value="a", id="a"),
             rx.input(default_value="a", id="a"),
             rx.input(default_value="b", id="b"),
             rx.input(default_value="b", id="b"),
             rx.input(default_value="c", id="c"),
             rx.input(default_value="c", id="c"),
@@ -106,7 +111,12 @@ def driver(server_side_event: AppHarness):
     assert server_side_event.app_instance is not None, "app is not running"
     assert server_side_event.app_instance is not None, "app is not running"
     driver = server_side_event.frontend()
     driver = server_side_event.frontend()
     try:
     try:
-        assert server_side_event.poll_for_clients()
+        token_input = driver.find_element(By.ID, "token")
+        assert token_input
+        # wait for the backend connection to send the token
+        token = server_side_event.poll_for_value(token_input)
+        assert token is not None
+
         yield driver
         yield driver
     finally:
     finally:
         driver.quit()
         driver.quit()

+ 16 - 9
integration/test_upload.py

@@ -89,13 +89,13 @@ def driver(upload_file: AppHarness):
     assert upload_file.app_instance is not None, "app is not running"
     assert upload_file.app_instance is not None, "app is not running"
     driver = upload_file.frontend()
     driver = upload_file.frontend()
     try:
     try:
-        assert upload_file.poll_for_clients()
         yield driver
         yield driver
     finally:
     finally:
         driver.quit()
         driver.quit()
 
 
 
 
-def test_upload_file(tmp_path, upload_file: AppHarness, driver):
+@pytest.mark.asyncio
+async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
     """Submit a file upload and check that it arrived on the backend.
     """Submit a file upload and check that it arrived on the backend.
 
 
     Args:
     Args:
@@ -124,16 +124,20 @@ def test_upload_file(tmp_path, upload_file: AppHarness, driver):
     upload_button.click()
     upload_button.click()
 
 
     # look up the backend state and assert on uploaded contents
     # look up the backend state and assert on uploaded contents
-    backend_state = upload_file.app_instance.state_manager.states[token]
-    time.sleep(0.5)
-    assert backend_state._file_data[exp_name] == exp_contents
+    async def get_file_data():
+        return (await upload_file.get_state(token))._file_data
+
+    file_data = await AppHarness._poll_for_async(get_file_data)
+    assert isinstance(file_data, dict)
+    assert file_data[exp_name] == exp_contents
 
 
     # check that the selected files are displayed
     # check that the selected files are displayed
     selected_files = driver.find_element(By.ID, "selected_files")
     selected_files = driver.find_element(By.ID, "selected_files")
     assert selected_files.text == exp_name
     assert selected_files.text == exp_name
 
 
 
 
-def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
+@pytest.mark.asyncio
+async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
     """Submit several file uploads and check that they arrived on the backend.
     """Submit several file uploads and check that they arrived on the backend.
 
 
     Args:
     Args:
@@ -173,10 +177,13 @@ def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
     upload_button.click()
     upload_button.click()
 
 
     # look up the backend state and assert on uploaded contents
     # look up the backend state and assert on uploaded contents
-    backend_state = upload_file.app_instance.state_manager.states[token]
-    time.sleep(0.5)
+    async def get_file_data():
+        return (await upload_file.get_state(token))._file_data
+
+    file_data = await AppHarness._poll_for_async(get_file_data)
+    assert isinstance(file_data, dict)
     for exp_name, exp_contents in exp_files.items():
     for exp_name, exp_contents in exp_files.items():
-        assert backend_state._file_data[exp_name] == exp_contents
+        assert file_data[exp_name] == exp_contents
 
 
 
 
 def test_clear_files(tmp_path, upload_file: AppHarness, driver):
 def test_clear_files(tmp_path, upload_file: AppHarness, driver):

+ 11 - 1
integration/test_var_operations.py

@@ -26,11 +26,16 @@ def VarOperations():
         dict1: dict = {1: 2}
         dict1: dict = {1: 2}
         dict2: dict = {3: 4}
         dict2: dict = {3: 4}
 
 
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
     app = rx.App(state=VarOperationState)
     app = rx.App(state=VarOperationState)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():
         return rx.vstack(
         return rx.vstack(
+            rx.input(id="token", value=VarOperationState.token, is_read_only=True),
             # INT INT
             # INT INT
             rx.text(
             rx.text(
                 VarOperationState.int_var1 + VarOperationState.int_var2,
                 VarOperationState.int_var1 + VarOperationState.int_var2,
@@ -544,7 +549,12 @@ def driver(var_operations: AppHarness):
     """
     """
     driver = var_operations.frontend()
     driver = var_operations.frontend()
     try:
     try:
-        assert var_operations.poll_for_clients()
+        token_input = driver.find_element(By.ID, "token")
+        assert token_input
+        # wait for the backend connection to send the token
+        token = var_operations.poll_for_value(token_input)
+        assert token is not None
+
         yield driver
         yield driver
     finally:
     finally:
         driver.quit()
         driver.quit()

+ 1 - 0
reflex/__init__.py

@@ -21,6 +21,7 @@ from .constants import Env as Env
 from .event import EVENT_ARG as EVENT_ARG
 from .event import EVENT_ARG as EVENT_ARG
 from .event import EventChain as EventChain
 from .event import EventChain as EventChain
 from .event import FileUpload as upload_files
 from .event import FileUpload as upload_files
+from .event import background as background
 from .event import clear_local_storage as clear_local_storage
 from .event import clear_local_storage as clear_local_storage
 from .event import console_log as console_log
 from .event import console_log as console_log
 from .event import download as download
 from .event import download as download

+ 164 - 73
reflex/app.py

@@ -2,6 +2,7 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import asyncio
 import asyncio
+import contextlib
 import inspect
 import inspect
 import os
 import os
 from multiprocessing.pool import ThreadPool
 from multiprocessing.pool import ThreadPool
@@ -13,6 +14,7 @@ from typing import (
     Dict,
     Dict,
     List,
     List,
     Optional,
     Optional,
+    Set,
     Type,
     Type,
     Union,
     Union,
 )
 )
@@ -49,7 +51,13 @@ from reflex.route import (
     get_route_args,
     get_route_args,
     verify_route_validity,
     verify_route_validity,
 )
 )
-from reflex.state import DefaultState, State, StateManager, StateUpdate
+from reflex.state import (
+    DefaultState,
+    State,
+    StateManager,
+    StateManagerMemory,
+    StateUpdate,
+)
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils import console, format, prerequisites, types
 from reflex.vars import ImportVar
 from reflex.vars import ImportVar
 
 
@@ -89,7 +97,7 @@ class App(Base):
     state: Type[State] = DefaultState
     state: Type[State] = DefaultState
 
 
     # Class to manage many client states.
     # Class to manage many client states.
-    state_manager: StateManager = StateManager()
+    state_manager: StateManager = StateManagerMemory(state=DefaultState)
 
 
     # The styling to apply to each component.
     # The styling to apply to each component.
     style: ComponentStyle = {}
     style: ComponentStyle = {}
@@ -104,13 +112,16 @@ class App(Base):
     admin_dash: Optional[AdminDash] = None
     admin_dash: Optional[AdminDash] = None
 
 
     # The async server name space
     # The async server name space
-    event_namespace: Optional[AsyncNamespace] = None
+    event_namespace: Optional[EventNamespace] = None
 
 
     # A component that is present on every page.
     # A component that is present on every page.
     overlay_component: Optional[
     overlay_component: Optional[
         Union[Component, ComponentCallable]
         Union[Component, ComponentCallable]
     ] = default_overlay_component
     ] = default_overlay_component
 
 
+    # Background tasks that are currently running
+    background_tasks: Set[asyncio.Task] = set()
+
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         """Initialize the app.
         """Initialize the app.
 
 
@@ -154,7 +165,7 @@ class App(Base):
         self.middleware.append(HydrateMiddleware())
         self.middleware.append(HydrateMiddleware())
 
 
         # Set up the state manager.
         # Set up the state manager.
-        self.state_manager.setup(state=self.state)
+        self.state_manager = StateManager.create(state=self.state)
 
 
         # Set up the API.
         # Set up the API.
         self.api = FastAPI()
         self.api = FastAPI()
@@ -646,6 +657,76 @@ class App(Base):
         thread_pool.close()
         thread_pool.close()
         thread_pool.join()
         thread_pool.join()
 
 
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state out of band.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state to modify.
+
+        Raises:
+            RuntimeError: If the app has not been initialized yet.
+        """
+        if self.event_namespace is None:
+            raise RuntimeError("App has not been initialized yet.")
+        # Get exclusive access to the state.
+        async with self.state_manager.modify_state(token) as state:
+            # No other event handler can modify the state while in this context.
+            yield state
+            delta = state.get_delta()
+            if delta:
+                # When the state is modified reset dirty status and emit the delta to the frontend.
+                state._clean()
+                await self.event_namespace.emit_update(
+                    update=StateUpdate(delta=delta),
+                    sid=state.get_sid(),
+                )
+
+    def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
+        """Process an event in the background and emit updates as they arrive.
+
+        Args:
+            state: The state to process the event for.
+            event: The event to process.
+
+        Returns:
+            Task if the event was backgroundable, otherwise None
+        """
+        substate, handler = state._get_event_handler(event)
+        if not handler.is_background:
+            return None
+
+        async def _coro():
+            """Coroutine to process the event and emit updates inside an asyncio.Task.
+
+            Raises:
+                RuntimeError: If the app has not been initialized yet.
+            """
+            if self.event_namespace is None:
+                raise RuntimeError("App has not been initialized yet.")
+
+            # Process the event.
+            async for update in state._process_event(
+                handler=handler, state=substate, payload=event.payload
+            ):
+                # Postprocess the event.
+                update = await self.postprocess(state, event, update)
+
+                # Send the update to the client.
+                await self.event_namespace.emit_update(
+                    update=update,
+                    sid=state.get_sid(),
+                )
+
+        task = asyncio.create_task(_coro())
+        self.background_tasks.add(task)
+        # Clean up task from background_tasks set when complete.
+        task.add_done_callback(self.background_tasks.discard)
+        return task
+
 
 
 async def process(
 async def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
@@ -662,9 +743,6 @@ async def process(
     Yields:
     Yields:
         The state updates after processing the event.
         The state updates after processing the event.
     """
     """
-    # Get the state for the session.
-    state = app.state_manager.get_state(event.token)
-
     # Add request data to the state.
     # Add request data to the state.
     router_data = event.router_data
     router_data = event.router_data
     router_data.update(
     router_data.update(
@@ -676,31 +754,35 @@ async def process(
             constants.RouteVar.CLIENT_IP: client_ip,
             constants.RouteVar.CLIENT_IP: client_ip,
         }
         }
     )
     )
-    # re-assign only when the value is different
-    if state.router_data != router_data:
-        # assignment will recurse into substates and force recalculation of
-        # dependent ComputedVar (dynamic route variables)
-        state.router_data = router_data
-
-    # Preprocess the event.
-    update = await app.preprocess(state, event)
-
-    # If there was an update, yield it.
-    if update is not None:
-        yield update
-
-    # Only process the event if there is no update.
-    else:
-        # Process the event.
-        async for update in state._process(event):
-            # Postprocess the event.
-            update = await app.postprocess(state, event, update)
-
-            # Yield the update.
+    # Get the state for the session exclusively.
+    async with app.state_manager.modify_state(event.token) as state:
+        # re-assign only when the value is different
+        if state.router_data != router_data:
+            # assignment will recurse into substates and force recalculation of
+            # dependent ComputedVar (dynamic route variables)
+            state.router_data = router_data
+
+        # Preprocess the event.
+        update = await app.preprocess(state, event)
+
+        # If there was an update, yield it.
+        if update is not None:
             yield update
             yield update
 
 
-    # Set the state for the session.
-    app.state_manager.set_state(event.token, state)
+        # Only process the event if there is no update.
+        else:
+            if app._process_background(state, event) is not None:
+                # `final=True` allows the frontend send more events immediately.
+                yield StateUpdate(final=True)
+                return
+
+            # Process the event synchronously.
+            async for update in state._process(event):
+                # Postprocess the event.
+                update = await app.postprocess(state, event, update)
+
+                # Yield the update.
+                yield update
 
 
 
 
 async def ping() -> str:
 async def ping() -> str:
@@ -737,47 +819,46 @@ def upload(app: App):
             assert file.filename is not None
             assert file.filename is not None
             file.filename = file.filename.split(":")[-1]
             file.filename = file.filename.split(":")[-1]
         # Get the state for the session.
         # Get the state for the session.
-        state = app.state_manager.get_state(token)
-        # get the current session ID
-        sid = state.get_sid()
-        # get the current state(parent state/substate)
-        path = handler.split(".")[:-1]
-        current_state = state.get_substate(path)
-        handler_upload_param = ()
-
-        # get handler function
-        func = getattr(current_state, handler.split(".")[-1])
-
-        # check if there exists any handler args with annotation, List[UploadFile]
-        for k, v in inspect.getfullargspec(
-            func.fn if isinstance(func, EventHandler) else func
-        ).annotations.items():
-            if types.is_generic_alias(v) and types._issubclass(
-                v.__args__[0], UploadFile
-            ):
-                handler_upload_param = (k, v)
-                break
+        async with app.state_manager.modify_state(token) as state:
+            # get the current session ID
+            sid = state.get_sid()
+            # get the current state(parent state/substate)
+            path = handler.split(".")[:-1]
+            current_state = state.get_substate(path)
+            handler_upload_param = ()
+
+            # get handler function
+            func = getattr(current_state, handler.split(".")[-1])
+
+            # check if there exists any handler args with annotation, List[UploadFile]
+            for k, v in inspect.getfullargspec(
+                func.fn if isinstance(func, EventHandler) else func
+            ).annotations.items():
+                if types.is_generic_alias(v) and types._issubclass(
+                    v.__args__[0], UploadFile
+                ):
+                    handler_upload_param = (k, v)
+                    break
+
+            if not handler_upload_param:
+                raise ValueError(
+                    f"`{handler}` handler should have a parameter annotated as List["
+                    f"rx.UploadFile]"
+                )
 
 
-        if not handler_upload_param:
-            raise ValueError(
-                f"`{handler}` handler should have a parameter annotated as List["
-                f"rx.UploadFile]"
+            event = Event(
+                token=token,
+                name=handler,
+                payload={handler_upload_param[0]: files},
             )
             )
-
-        event = Event(
-            token=token,
-            name=handler,
-            payload={handler_upload_param[0]: files},
-        )
-        async for update in state._process(event):
-            # Postprocess the event.
-            update = await app.postprocess(state, event, update)
-            # Send update to client
-            await asyncio.create_task(
-                app.event_namespace.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)  # type: ignore
-            )
-        # Set the state for the session.
-        app.state_manager.set_state(event.token, state)
+            async for update in state._process(event):
+                # Postprocess the event.
+                update = await app.postprocess(state, event, update)
+                # Send update to client
+                await app.event_namespace.emit_update(  # type: ignore
+                    update=update,
+                    sid=sid,
+                )
 
 
     return upload_file
     return upload_file
 
 
@@ -815,6 +896,18 @@ class EventNamespace(AsyncNamespace):
         """
         """
         pass
         pass
 
 
+    async def emit_update(self, update: StateUpdate, sid: str) -> None:
+        """Emit an update to the client.
+
+        Args:
+            update: The state update to send.
+            sid: The Socket.IO session id.
+        """
+        # Creating a task prevents the update from being blocked behind other coroutines.
+        await asyncio.create_task(
+            self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
+        )
+
     async def on_event(self, sid, data):
     async def on_event(self, sid, data):
         """Event for receiving front-end websocket events.
         """Event for receiving front-end websocket events.
 
 
@@ -841,10 +934,8 @@ class EventNamespace(AsyncNamespace):
 
 
         # Process the events.
         # Process the events.
         async for update in process(self.app, event, sid, headers, client_ip):
         async for update in process(self.app, event, sid, headers, client_ip):
-            # Emit the event.
-            await asyncio.create_task(
-                self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
-            )
+            # Emit the update from processing the event.
+            await self.emit_update(update=update, sid=sid)
 
 
     async def on_ping(self, sid):
     async def on_ping(self, sid):
         """Event for testing the API endpoint.
         """Event for testing the API endpoint.

+ 8 - 0
reflex/app.pyi

@@ -1,5 +1,6 @@
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
 
 
+import asyncio
 from fastapi import FastAPI
 from fastapi import FastAPI
 from fastapi import UploadFile as UploadFile
 from fastapi import UploadFile as UploadFile
 from reflex import constants as constants
 from reflex import constants as constants
@@ -45,12 +46,14 @@ from reflex.utils import (
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from typing import (
 from typing import (
     Any,
     Any,
+    AsyncContextManager,
     AsyncIterator,
     AsyncIterator,
     Callable,
     Callable,
     Coroutine,
     Coroutine,
     Dict,
     Dict,
     List,
     List,
     Optional,
     Optional,
+    Set,
     Type,
     Type,
     Union,
     Union,
     overload,
     overload,
@@ -75,6 +78,7 @@ class App(Base):
     admin_dash: Optional[AdminDash]
     admin_dash: Optional[AdminDash]
     event_namespace: Optional[AsyncNamespace]
     event_namespace: Optional[AsyncNamespace]
     overlay_component: Optional[Union[Component, ComponentCallable]]
     overlay_component: Optional[Union[Component, ComponentCallable]]
+    background_tasks: Set[asyncio.Task] = set()
     def __init__(
     def __init__(
         self,
         self,
         *args,
         *args,
@@ -116,6 +120,10 @@ class App(Base):
     def setup_admin_dash(self) -> None: ...
     def setup_admin_dash(self) -> None: ...
     def get_frontend_packages(self, imports: Dict[str, str]): ...
     def get_frontend_packages(self, imports: Dict[str, str]): ...
     def compile(self) -> None: ...
     def compile(self) -> None: ...
+    def modify_state(self, token: str) -> AsyncContextManager[State]: ...
+    def _process_background(
+        self, state: State, event: Event
+    ) -> asyncio.Task | None: ...
 
 
 async def process(
 async def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
     app: App, event: Event, sid: str, headers: Dict, client_ip: str

+ 2 - 0
reflex/constants.py

@@ -219,6 +219,8 @@ OLD_CONFIG_FILE = f"pcconfig{PY_EXT}"
 PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app"
 PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app"
 # Token expiration time in seconds.
 # Token expiration time in seconds.
 TOKEN_EXPIRATION = 60 * 60
 TOKEN_EXPIRATION = 60 * 60
+# Maximum time in milliseconds that a state can be locked for exclusive access.
+LOCK_EXPIRATION = 10000
 
 
 # Testing variables.
 # Testing variables.
 # Testing os env set by pytest when running a test case.
 # Testing os env set by pytest when running a test case.

+ 84 - 2
reflex/event.py

@@ -2,7 +2,17 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import inspect
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
@@ -10,6 +20,9 @@ from reflex.utils import console, format
 from reflex.utils.types import ArgsSpec
 from reflex.utils.types import ArgsSpec
 from reflex.vars import BaseVar, Var
 from reflex.vars import BaseVar, Var
 
 
+if TYPE_CHECKING:
+    from reflex.state import State
+
 
 
 class Event(Base):
 class Event(Base):
     """An event that describes any state change in the app."""
     """An event that describes any state change in the app."""
@@ -27,6 +40,66 @@ class Event(Base):
     payload: Dict[str, Any] = {}
     payload: Dict[str, Any] = {}
 
 
 
 
+BACKGROUND_TASK_MARKER = "_reflex_background_task"
+
+
+def background(fn):
+    """Decorator to mark event handler as running in the background.
+
+    Args:
+        fn: The function to decorate.
+
+    Returns:
+        The same function, but with a marker set.
+
+
+    Raises:
+        TypeError: If the function is not a coroutine function or async generator.
+    """
+    if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn):
+        raise TypeError("Background task must be async function or generator.")
+    setattr(fn, BACKGROUND_TASK_MARKER, True)
+    return fn
+
+
+def _no_chain_background_task(
+    state_cls: Type["State"], name: str, fn: Callable
+) -> Callable:
+    """Protect against directly chaining a background task from another event handler.
+
+    Args:
+        state_cls: The state class that the event handler is in.
+        name: The name of the background task.
+        fn: The background task coroutine function / generator.
+
+    Returns:
+        A compatible coroutine function / generator that raises a runtime error.
+
+    Raises:
+        TypeError: If the background task is not async.
+    """
+    call = f"{state_cls.__name__}.{name}"
+    message = (
+        f"Cannot directly call background task {name!r}, use "
+        f"`yield {call}` or `return {call}` instead."
+    )
+    if inspect.iscoroutinefunction(fn):
+
+        async def _no_chain_background_task_co(*args, **kwargs):
+            raise RuntimeError(message)
+
+        return _no_chain_background_task_co
+    if inspect.isasyncgenfunction(fn):
+
+        async def _no_chain_background_task_gen(*args, **kwargs):
+            yield
+            raise RuntimeError(message)
+
+        return _no_chain_background_task_gen
+
+    raise TypeError(f"{fn} is marked as a background task, but is not async.")
+
+
 class EventHandler(Base):
 class EventHandler(Base):
     """An event handler responds to an event to update the state."""
     """An event handler responds to an event to update the state."""
 
 
@@ -39,6 +112,15 @@ class EventHandler(Base):
         # Needed to allow serialization of Callable.
         # Needed to allow serialization of Callable.
         frozen = True
         frozen = True
 
 
+    @property
+    def is_background(self) -> bool:
+        """Whether the event handler is a background task.
+
+        Returns:
+            True if the event handler is marked as a background task.
+        """
+        return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
+
     def __call__(self, *args: Var) -> EventSpec:
     def __call__(self, *args: Var) -> EventSpec:
         """Pass arguments to the handler to get an event spec.
         """Pass arguments to the handler to get an event spec.
 
 
@@ -530,7 +612,7 @@ def get_handler_args(event_spec: EventSpec) -> tuple[tuple[Var, Var], ...]:
 
 
 
 
 def fix_events(
 def fix_events(
-    events: list[EventHandler | EventSpec],
+    events: list[EventHandler | EventSpec] | None,
     token: str,
     token: str,
     router_data: dict[str, Any] | None = None,
     router_data: dict[str, Any] | None = None,
 ) -> list[Event]:
 ) -> list[Event]:

+ 546 - 69
reflex/state.py

@@ -2,13 +2,15 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import asyncio
 import asyncio
+import contextlib
 import copy
 import copy
 import functools
 import functools
 import inspect
 import inspect
 import json
 import json
 import traceback
 import traceback
 import urllib.parse
 import urllib.parse
-from abc import ABC
+import uuid
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from collections import defaultdict
 from types import FunctionType
 from types import FunctionType
 from typing import (
 from typing import (
@@ -27,12 +29,20 @@ from typing import (
 import cloudpickle
 import cloudpickle
 import pydantic
 import pydantic
 import wrapt
 import wrapt
-from redis import Redis
+from redis.asyncio import Redis
 
 
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
-from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
+from reflex.event import (
+    Event,
+    EventHandler,
+    EventSpec,
+    _no_chain_background_task,
+    fix_events,
+    window_alert,
+)
 from reflex.utils import format, prerequisites, types
 from reflex.utils import format, prerequisites, types
+from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.vars import BaseVar, ComputedVar, Var
 from reflex.vars import BaseVar, ComputedVar, Var
 
 
 Delta = Dict[str, Any]
 Delta = Dict[str, Any]
@@ -152,7 +162,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
 
         # Convert the event handlers to functions.
         # Convert the event handlers to functions.
         for name, event_handler in state.event_handlers.items():
         for name, event_handler in state.event_handlers.items():
-            fn = functools.partial(event_handler.fn, self)
+            if event_handler.is_background:
+                fn = _no_chain_background_task(type(state), name, event_handler.fn)
+            else:
+                fn = functools.partial(event_handler.fn, self)
             fn.__module__ = event_handler.fn.__module__  # type: ignore
             fn.__module__ = event_handler.fn.__module__  # type: ignore
             fn.__qualname__ = event_handler.fn.__qualname__  # type: ignore
             fn.__qualname__ = event_handler.fn.__qualname__  # type: ignore
             setattr(self, name, fn)
             setattr(self, name, fn)
@@ -711,52 +724,56 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             raise ValueError(f"Invalid path: {path}")
             raise ValueError(f"Invalid path: {path}")
         return self.substates[path[0]].get_substate(path[1:])
         return self.substates[path[0]].get_substate(path[1:])
 
 
-    async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
-        """Obtain event info and process event.
+    def _get_event_handler(
+        self, event: Event
+    ) -> tuple[State | StateProxy, EventHandler]:
+        """Get the event handler for the given event.
 
 
         Args:
         Args:
-            event: The event to process.
+            event: The event to get the handler for.
 
 
-        Yields:
-            The state update after processing the event.
+
+        Returns:
+            The event handler.
 
 
         Raises:
         Raises:
-            ValueError: If the state value is None.
+            ValueError: If the event handler or substate is not found.
         """
         """
         # Get the event handler.
         # Get the event handler.
         path = event.name.split(".")
         path = event.name.split(".")
         path, name = path[:-1], path[-1]
         path, name = path[:-1], path[-1]
         substate = self.get_substate(path)
         substate = self.get_substate(path)
-        handler = substate.event_handlers[name]  # type: ignore
-
         if not substate:
         if not substate:
             raise ValueError(
             raise ValueError(
                 "The value of state cannot be None when processing an event."
                 "The value of state cannot be None when processing an event."
             )
             )
+        handler = substate.event_handlers[name]
 
 
-        # Get the event generator.
-        event_iter = self._process_event(
-            handler=handler,
-            state=substate,
-            payload=event.payload,
-        )
+        # For background tasks, proxy the state
+        if handler.is_background:
+            substate = StateProxy(substate)
 
 
-        # Clean the state before processing the event.
-        self._clean()
+        return substate, handler
 
 
-        # Run the event generator and return state updates.
-        async for events, final in event_iter:
-            # Fix the returned events.
-            events = fix_events(events, event.token)  # type: ignore
+    async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
+        """Obtain event info and process event.
 
 
-            # Get the delta after processing the event.
-            delta = self.get_delta()
+        Args:
+            event: The event to process.
 
 
-            # Yield the state update.
-            yield StateUpdate(delta=delta, events=events, final=final)
+        Yields:
+            The state update after processing the event.
+        """
+        # Get the event handler.
+        substate, handler = self._get_event_handler(event)
 
 
-            # Clean the state to prepare for the next event.
-            self._clean()
+        # Run the event generator and yield state updates.
+        async for update in self._process_event(
+            handler=handler,
+            state=substate,
+            payload=event.payload,
+        ):
+            yield update
 
 
     def _check_valid(self, handler: EventHandler, events: Any) -> Any:
     def _check_valid(self, handler: EventHandler, events: Any) -> Any:
         """Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
         """Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
@@ -787,9 +804,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
             f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
         )
         )
 
 
+    def _as_state_update(
+        self,
+        handler: EventHandler,
+        events: EventSpec | list[EventSpec] | None,
+        final: bool,
+    ) -> StateUpdate:
+        """Convert the events to a StateUpdate.
+
+        Fixes the events and checks for validity before converting.
+
+        Args:
+            handler: The handler where the events originated from.
+            events: The events to queue with the update.
+            final: Whether the handler is done processing.
+
+        Returns:
+            The valid StateUpdate containing the events and final flag.
+        """
+        token = self.get_token()
+
+        # Convert valid EventHandler and EventSpec into Event
+        fixed_events = fix_events(self._check_valid(handler, events), token)
+
+        # Get the delta after processing the event.
+        delta = self.get_delta()
+        self._clean()
+
+        return StateUpdate(
+            delta=delta,
+            events=fixed_events,
+            final=final if not handler.is_background else True,
+        )
+
     async def _process_event(
     async def _process_event(
-        self, handler: EventHandler, state: State, payload: Dict
-    ) -> AsyncIterator[tuple[list[EventSpec] | None, bool]]:
+        self, handler: EventHandler, state: State | StateProxy, payload: Dict
+    ) -> AsyncIterator[StateUpdate]:
         """Process event.
         """Process event.
 
 
         Args:
         Args:
@@ -798,13 +848,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             payload: The event payload.
             payload: The event payload.
 
 
         Yields:
         Yields:
-            Tuple containing:
-                0: The state update after processing the event.
-                1: Whether the event is the final event.
+            StateUpdate object
         """
         """
         # Get the function to process the event.
         # Get the function to process the event.
         fn = functools.partial(handler.fn, state)
         fn = functools.partial(handler.fn, state)
 
 
+        # Clean the state before processing the event.
+        self._clean()
+
         # Wrap the function in a try/except block.
         # Wrap the function in a try/except block.
         try:
         try:
             # Handle async functions.
             # Handle async functions.
@@ -817,30 +868,34 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             # Handle async generators.
             # Handle async generators.
             if inspect.isasyncgen(events):
             if inspect.isasyncgen(events):
                 async for event in events:
                 async for event in events:
-                    yield self._check_valid(handler, event), False
-                yield None, True
+                    yield self._as_state_update(handler, event, final=False)
+                yield self._as_state_update(handler, events=None, final=True)
 
 
             # Handle regular generators.
             # Handle regular generators.
             elif inspect.isgenerator(events):
             elif inspect.isgenerator(events):
                 try:
                 try:
                     while True:
                     while True:
-                        yield self._check_valid(handler, next(events)), False
+                        yield self._as_state_update(handler, next(events), final=False)
                 except StopIteration as si:
                 except StopIteration as si:
                     # the "return" value of the generator is not available
                     # the "return" value of the generator is not available
                     # in the loop, we must catch StopIteration to access it
                     # in the loop, we must catch StopIteration to access it
                     if si.value is not None:
                     if si.value is not None:
-                        yield self._check_valid(handler, si.value), False
-                yield None, True
+                        yield self._as_state_update(handler, si.value, final=False)
+                yield self._as_state_update(handler, events=None, final=True)
 
 
             # Handle regular event chains.
             # Handle regular event chains.
             else:
             else:
-                yield self._check_valid(handler, events), True
+                yield self._as_state_update(handler, events, final=True)
 
 
         # If an error occurs, throw a window alert.
         # If an error occurs, throw a window alert.
         except Exception:
         except Exception:
             error = traceback.format_exc()
             error = traceback.format_exc()
             print(error)
             print(error)
-            yield [window_alert("An error occurred. See logs for details.")], True
+            yield self._as_state_update(
+                handler,
+                window_alert("An error occurred. See logs for details."),
+                final=True,
+            )
 
 
     def _always_dirty_computed_vars(self) -> set[str]:
     def _always_dirty_computed_vars(self) -> set[str]:
         """The set of ComputedVars that always need to be recalculated.
         """The set of ComputedVars that always need to be recalculated.
@@ -989,6 +1044,160 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         variables = {**base_vars, **computed_vars, **substate_vars}
         variables = {**base_vars, **computed_vars, **substate_vars}
         return {k: variables[k] for k in sorted(variables)}
         return {k: variables[k] for k in sorted(variables)}
 
 
+    async def __aenter__(self) -> State:
+        """Enter the async context manager protocol.
+
+        This should not be used for the State class, but exists for
+        type-compatibility with StateProxy.
+
+        Raises:
+            TypeError: always, because async contextmanager protocol is only supported for background task.
+        """
+        raise TypeError(
+            "Only background task should use `async with self` to modify state."
+        )
+
+    async def __aexit__(self, *exc_info: Any) -> None:
+        """Exit the async context manager protocol.
+
+        This should not be used for the State class, but exists for
+        type-compatibility with StateProxy.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        pass
+
+
+class StateProxy(wrapt.ObjectProxy):
+    """Proxy of a state instance to control mutability of vars for a background task.
+
+    Since a background task runs against a state instance without holding the
+    state_manager lock for the token, the reference may become stale if the same
+    state is modified by another event handler.
+
+    The proxy object ensures that writes to the state are blocked unless
+    explicitly entering a context which refreshes the state from state_manager
+    and holds the lock for the token until exiting the context. After exiting
+    the context, a StateUpdate may be emitted to the frontend to notify the
+    client of the state change.
+
+    A background task will be passed the `StateProxy` as `self`, so mutability
+    can be safely performed inside an `async with self` block.
+
+        class State(rx.State):
+            counter: int = 0
+
+            @rx.background
+            async def bg_increment(self):
+                await asyncio.sleep(1)
+                async with self:
+                    self.counter += 1
+    """
+
+    def __init__(self, state_instance):
+        """Create a proxy for a state instance.
+
+        Args:
+            state_instance: The state instance to proxy.
+        """
+        super().__init__(state_instance)
+        self._self_app = getattr(prerequisites.get_app(), constants.APP_VAR)
+        self._self_substate_path = state_instance.get_full_name().split(".")
+        self._self_actx = None
+        self._self_mutable = False
+
+    async def __aenter__(self) -> StateProxy:
+        """Enter the async context manager protocol.
+
+        Sets mutability to True and enters the `App.modify_state` async context,
+        which refreshes the state from state_manager and holds the lock for the
+        given state token until exiting the context.
+
+        Background tasks should avoid blocking calls while inside the context.
+
+        Returns:
+            This StateProxy instance in mutable mode.
+        """
+        self._self_actx = self._self_app.modify_state(self.__wrapped__.get_token())
+        mutable_state = await self._self_actx.__aenter__()
+        super().__setattr__(
+            "__wrapped__", mutable_state.get_substate(self._self_substate_path)
+        )
+        self._self_mutable = True
+        return self
+
+    async def __aexit__(self, *exc_info: Any) -> None:
+        """Exit the async context manager protocol.
+
+        Sets proxy mutability to False and persists any state changes.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        if self._self_actx is None:
+            return
+        self._self_mutable = False
+        await self._self_actx.__aexit__(*exc_info)
+        self._self_actx = None
+
+    def __enter__(self):
+        """Enter the regular context manager protocol.
+
+        This is not supported for background tasks, and exists only to raise a more useful exception
+        when the StateProxy is used incorrectly.
+
+        Raises:
+            TypeError: always, because only async contextmanager protocol is supported.
+        """
+        raise TypeError("Background task must use `async with self` to modify state.")
+
+    def __exit__(self, *exc_info: Any) -> None:
+        """Exit the regular context manager protocol.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        pass
+
+    def __getattr__(self, name: str) -> Any:
+        """Get the attribute from the underlying state instance.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The value of the attribute.
+        """
+        value = super().__getattr__(name)
+        if not name.startswith("_self_") and isinstance(value, MutableProxy):
+            # ensure mutations to these containers are blocked unless proxy is _mutable
+            return ImmutableMutableProxy(
+                wrapped=value.__wrapped__,
+                state=self,  # type: ignore
+                field_name=value._self_field_name,
+            )
+        return value
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set the attribute on the underlying state instance.
+
+        If the attribute is internal, set it on the proxy instance instead.
+
+        Args:
+            name: The name of the attribute.
+            value: The value of the attribute.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if not name.startswith("_self_") and not self._self_mutable:
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        super().__setattr__(name, value)
+
 
 
 class DefaultState(State):
 class DefaultState(State):
     """The default empty state."""
     """The default empty state."""
@@ -1009,31 +1218,83 @@ class StateUpdate(Base):
     final: bool = True
     final: bool = True
 
 
 
 
-class StateManager(Base):
+class StateManager(Base, ABC):
     """A class to manage many client states."""
     """A class to manage many client states."""
 
 
     # The state class to use.
     # The state class to use.
-    state: Type[State] = DefaultState
+    state: Type[State]
 
 
-    # The mapping of client ids to states.
-    states: Dict[str, State] = {}
+    @classmethod
+    def create(cls, state: Type[State] = DefaultState):
+        """Create a new state manager.
 
 
-    # The token expiration time (s).
-    token_expiration: int = constants.TOKEN_EXPIRATION
+        Args:
+            state: The state class to use.
 
 
-    # The redis client to use.
-    redis: Optional[Redis] = None
+        Returns:
+            The state manager (either memory or redis).
+        """
+        redis = prerequisites.get_redis()
+        if redis is not None:
+            return StateManagerRedis(state=state, redis=redis)
+        return StateManagerMemory(state=state)
 
 
-    def setup(self, state: Type[State]):
-        """Set up the state manager.
+    @abstractmethod
+    async def get_state(self, token: str) -> State:
+        """Get the state for a token.
 
 
         Args:
         Args:
-            state: The state class to use.
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        pass
+
+    @abstractmethod
+    async def set_state(self, token: str, state: State):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        pass
+
+    @abstractmethod
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
         """
         """
-        self.state = state
-        self.redis = prerequisites.get_redis()
+        yield self.state()
 
 
-    def get_state(self, token: str) -> State:
+
+class StateManagerMemory(StateManager):
+    """A state manager that stores states in memory."""
+
+    # The mapping of client ids to states.
+    states: Dict[str, State] = {}
+
+    # The mutex ensures the dict of mutexes is updated exclusively
+    _state_manager_lock = asyncio.Lock()
+
+    # The dict of mutexes for each client
+    _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
+
+    class Config:
+        """The Pydantic config."""
+
+        fields = {
+            "_states_locks": {"exclude": True},
+        }
+
+    async def get_state(self, token: str) -> State:
         """Get the state for a token.
         """Get the state for a token.
 
 
         Args:
         Args:
@@ -1042,27 +1303,212 @@ class StateManager(Base):
         Returns:
         Returns:
             The state for the token.
             The state for the token.
         """
         """
-        if self.redis is not None:
-            redis_state = self.redis.get(token)
-            if redis_state is None:
-                self.set_state(token, self.state())
-                return self.get_state(token)
-            return cloudpickle.loads(redis_state)
-
         if token not in self.states:
         if token not in self.states:
             self.states[token] = self.state()
             self.states[token] = self.state()
         return self.states[token]
         return self.states[token]
 
 
-    def set_state(self, token: str, state: State):
+    async def set_state(self, token: str, state: State):
         """Set the state for a token.
         """Set the state for a token.
 
 
         Args:
         Args:
             token: The token to set the state for.
             token: The token to set the state for.
             state: The state to set.
             state: The state to set.
         """
         """
-        if self.redis is None:
-            return
-        self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
+        pass
+
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        if token not in self._states_locks:
+            async with self._state_manager_lock:
+                if token not in self._states_locks:
+                    self._states_locks[token] = asyncio.Lock()
+
+        async with self._states_locks[token]:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state)
+
+
+class StateManagerRedis(StateManager):
+    """A state manager that stores states in redis."""
+
+    # The redis client to use.
+    redis: Redis
+
+    # The token expiration time (s).
+    token_expiration: int = constants.TOKEN_EXPIRATION
+
+    # The maximum time to hold a lock (ms).
+    lock_expiration: int = constants.LOCK_EXPIRATION
+
+    # The keyspace subscription string when redis is waiting for lock to be released
+    _redis_notify_keyspace_events: str = (
+        "K"  # Enable keyspace notifications (target a particular key)
+        "g"  # For generic commands (DEL, EXPIRE, etc)
+        "x"  # For expired events
+        "e"  # For evicted events (i.e. maxmemory exceeded)
+    )
+
+    # These events indicate that a lock is no longer held
+    _redis_keyspace_lock_release_events: Set[bytes] = {
+        b"del",
+        b"expire",
+        b"expired",
+        b"evicted",
+    }
+
+    async def get_state(self, token: str) -> State:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        redis_state = await self.redis.get(token)
+        if redis_state is None:
+            await self.set_state(token, self.state())
+            return await self.get_state(token)
+        return cloudpickle.loads(redis_state)
+
+    async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+            lock_id: If provided, the lock_key must be set to this value to set the state.
+
+        Raises:
+            LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
+        """
+        # check that we're holding the lock
+        if (
+            lock_id is not None
+            and await self.redis.get(self._lock_key(token)) != lock_id
+        ):
+            raise LockExpiredError(
+                f"Lock expired for token {token} while processing. Consider increasing "
+                f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
+                "or use `@rx.background` decorator for long-running tasks."
+            )
+        await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
+
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        async with self._lock(token) as lock_id:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state, lock_id)
+
+    @staticmethod
+    def _lock_key(token: str) -> bytes:
+        """Get the redis key for a token's lock.
+
+        Args:
+            token: The token to get the lock key for.
+
+        Returns:
+            The redis lock key for the token.
+        """
+        return f"{token}_lock".encode()
+
+    async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
+        """Try to get a redis lock for a token.
+
+        Args:
+            lock_key: The redis key for the lock.
+            lock_id: The ID of the lock.
+
+        Returns:
+            True if the lock was obtained.
+        """
+        return await self.redis.set(
+            lock_key,
+            lock_id,
+            px=self.lock_expiration,
+            nx=True,  # only set if it doesn't exist
+        )
+
+    async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
+        """Wait for a redis lock to be released via pubsub.
+
+        Coroutine will not return until the lock is obtained.
+
+        Args:
+            lock_key: The redis key for the lock.
+            lock_id: The ID of the lock.
+        """
+        state_is_locked = False
+        lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
+        # Enable keyspace notifications for the lock key, so we know when it is available.
+        await self.redis.config_set(
+            "notify-keyspace-events", self._redis_notify_keyspace_events
+        )
+        async with self.redis.pubsub() as pubsub:
+            await pubsub.psubscribe(lock_key_channel)
+            while not state_is_locked:
+                # wait for the lock to be released
+                while True:
+                    if not await self.redis.exists(lock_key):
+                        break  # key was removed, try to get the lock again
+                    message = await pubsub.get_message(
+                        ignore_subscribe_messages=True,
+                        timeout=self.lock_expiration / 1000.0,
+                    )
+                    if message is None:
+                        continue
+                    if message["data"] in self._redis_keyspace_lock_release_events:
+                        break
+                state_is_locked = await self._try_get_lock(lock_key, lock_id)
+
+    @contextlib.asynccontextmanager
+    async def _lock(self, token: str):
+        """Obtain a redis lock for a token.
+
+        Args:
+            token: The token to obtain a lock for.
+
+        Yields:
+            The ID of the lock (to be passed to set_state).
+
+        Raises:
+            LockExpiredError: If the lock has expired while processing the event.
+        """
+        lock_key = self._lock_key(token)
+        lock_id = uuid.uuid4().hex.encode()
+
+        if not await self._try_get_lock(lock_key, lock_id):
+            # Missed the fast-path to get lock, subscribe for lock delete/expire events
+            await self._wait_lock(lock_key, lock_id)
+        state_is_locked = True
+
+        try:
+            yield lock_id
+        except LockExpiredError:
+            state_is_locked = False
+            raise
+        finally:
+            if state_is_locked:
+                # only delete our lock
+                await self.redis.delete(lock_key)
 
 
 
 
 class ClientStorageBase:
 class ClientStorageBase:
@@ -1246,7 +1692,7 @@ class MutableProxy(wrapt.ObjectProxy):
             value, super().__getattribute__("__mutable_types__")
             value, super().__getattribute__("__mutable_types__")
         ) and __name not in ("__wrapped__", "_self_state"):
         ) and __name not in ("__wrapped__", "_self_state"):
             # Recursively wrap mutable attribute values retrieved through this proxy.
             # Recursively wrap mutable attribute values retrieved through this proxy.
-            return MutableProxy(
+            return type(self)(
                 wrapped=value,
                 wrapped=value,
                 state=self._self_state,
                 state=self._self_state,
                 field_name=self._self_field_name,
                 field_name=self._self_field_name,
@@ -1266,7 +1712,7 @@ class MutableProxy(wrapt.ObjectProxy):
         value = super().__getitem__(key)
         value = super().__getitem__(key)
         if isinstance(value, self.__mutable_types__):
         if isinstance(value, self.__mutable_types__):
             # Recursively wrap mutable items retrieved through this proxy.
             # Recursively wrap mutable items retrieved through this proxy.
-            return MutableProxy(
+            return type(self)(
                 wrapped=value,
                 wrapped=value,
                 state=self._self_state,
                 state=self._self_state,
                 field_name=self._self_field_name,
                 field_name=self._self_field_name,
@@ -1332,3 +1778,34 @@ class MutableProxy(wrapt.ObjectProxy):
             A deepcopy of the wrapped object, unconnected to the proxy.
             A deepcopy of the wrapped object, unconnected to the proxy.
         """
         """
         return copy.deepcopy(self.__wrapped__, memo=memo)
         return copy.deepcopy(self.__wrapped__, memo=memo)
+
+
+class ImmutableMutableProxy(MutableProxy):
+    """A proxy for a mutable object that tracks changes.
+
+    This wrapper comes from StateProxy, and will raise an exception if an attempt is made
+    to modify the wrapped object when the StateProxy is immutable.
+    """
+
+    def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
+        """Raise an exception when an attempt is made to modify the object.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Raises:
+            ImmutableStateError: if the StateProxy is not mutable.
+        """
+        if not self._self_state._self_mutable:
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        super()._mark_dirty(
+            wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
+        )

+ 122 - 30
reflex/testing.py

@@ -1,6 +1,7 @@
 """reflex.testing - tools for testing reflex apps."""
 """reflex.testing - tools for testing reflex apps."""
 from __future__ import annotations
 from __future__ import annotations
 
 
+import asyncio
 import contextlib
 import contextlib
 import dataclasses
 import dataclasses
 import inspect
 import inspect
@@ -19,14 +20,13 @@ import types
 from http.server import SimpleHTTPRequestHandler
 from http.server import SimpleHTTPRequestHandler
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
-    Any,
+    AsyncIterator,
     Callable,
     Callable,
     Coroutine,
     Coroutine,
     Optional,
     Optional,
     Type,
     Type,
     TypeVar,
     TypeVar,
     Union,
     Union,
-    cast,
 )
 )
 
 
 import psutil
 import psutil
@@ -38,7 +38,7 @@ import reflex.utils.build
 import reflex.utils.exec
 import reflex.utils.exec
 import reflex.utils.prerequisites
 import reflex.utils.prerequisites
 import reflex.utils.processes
 import reflex.utils.processes
-from reflex.app import EventNamespace
+from reflex.state import State, StateManagerMemory, StateManagerRedis
 
 
 try:
 try:
     from selenium import webdriver  # pyright: ignore [reportMissingImports]
     from selenium import webdriver  # pyright: ignore [reportMissingImports]
@@ -109,6 +109,7 @@ class AppHarness:
     frontend_url: Optional[str] = None
     frontend_url: Optional[str] = None
     backend_thread: Optional[threading.Thread] = None
     backend_thread: Optional[threading.Thread] = None
     backend: Optional[uvicorn.Server] = None
     backend: Optional[uvicorn.Server] = None
+    state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
     _frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
     _frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
 
 
     @classmethod
     @classmethod
@@ -162,6 +163,27 @@ class AppHarness:
             reflex.config.get_config(reload=True)
             reflex.config.get_config(reload=True)
             self.app_module = reflex.utils.prerequisites.get_app(reload=True)
             self.app_module = reflex.utils.prerequisites.get_app(reload=True)
         self.app_instance = self.app_module.app
         self.app_instance = self.app_module.app
+        if isinstance(self.app_instance.state_manager, StateManagerRedis):
+            # Create our own redis connection for testing.
+            self.state_manager = StateManagerRedis.create(self.app_instance.state)
+        else:
+            self.state_manager = self.app_instance.state_manager
+
+    def _get_backend_shutdown_handler(self):
+        if self.backend is None:
+            raise RuntimeError("Backend was not initialized.")
+
+        original_shutdown = self.backend.shutdown
+
+        async def _shutdown_redis(*args, **kwargs) -> None:
+            # ensure redis is closed before event loop
+            if self.app_instance is not None and isinstance(
+                self.app_instance.state_manager, StateManagerRedis
+            ):
+                await self.app_instance.state_manager.redis.close()
+            await original_shutdown(*args, **kwargs)
+
+        return _shutdown_redis
 
 
     def _start_backend(self, port=0):
     def _start_backend(self, port=0):
         if self.app_instance is None:
         if self.app_instance is None:
@@ -173,6 +195,7 @@ class AppHarness:
                 port=port,
                 port=port,
             )
             )
         )
         )
+        self.backend.shutdown = self._get_backend_shutdown_handler()
         self.backend_thread = threading.Thread(target=self.backend.run)
         self.backend_thread = threading.Thread(target=self.backend.run)
         self.backend_thread.start()
         self.backend_thread.start()
 
 
@@ -296,6 +319,35 @@ class AppHarness:
             time.sleep(step)
             time.sleep(step)
         return False
         return False
 
 
+    @staticmethod
+    async def _poll_for_async(
+        target: Callable[[], Coroutine[None, None, T]],
+        timeout: TimeoutType = None,
+        step: TimeoutType = None,
+    ) -> T | bool:
+        """Generic polling logic for async functions.
+
+        Args:
+            target: callable that returns truthy if polling condition is met.
+            timeout: max polling time
+            step: interval between checking target()
+
+        Returns:
+            return value of target() if truthy within timeout
+            False if timeout elapses
+        """
+        if timeout is None:
+            timeout = DEFAULT_TIMEOUT
+        if step is None:
+            step = POLL_INTERVAL
+        deadline = time.time() + timeout
+        while time.time() < deadline:
+            success = await target()
+            if success:
+                return success
+            await asyncio.sleep(step)
+        return False
+
     def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
     def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
         """Poll backend server for listening sockets.
         """Poll backend server for listening sockets.
 
 
@@ -351,39 +403,76 @@ class AppHarness:
         self._frontends.append(driver)
         self._frontends.append(driver)
         return driver
         return driver
 
 
-    async def emit_state_updates(self) -> list[Any]:
-        """Send any backend state deltas to the frontend.
+    async def get_state(self, token: str) -> State:
+        """Get the state associated with the given token.
+
+        Args:
+            token: The state token to look up.
 
 
         Returns:
         Returns:
-            List of awaited response from each EventNamespace.emit() call.
+            The state instance associated with the given token
+
+        Raises:
+            RuntimeError: when the app hasn't started running
+        """
+        if self.state_manager is None:
+            raise RuntimeError("state_manager is not set.")
+        try:
+            return await self.state_manager.get_state(token)
+        finally:
+            if isinstance(self.state_manager, StateManagerRedis):
+                await self.state_manager.redis.close()
+
+    async def set_state(self, token: str, **kwargs) -> None:
+        """Set the state associated with the given token.
+
+        Args:
+            token: The state token to set.
+            kwargs: Attributes to set on the state.
+
+        Raises:
+            RuntimeError: when the app hasn't started running
+        """
+        if self.state_manager is None:
+            raise RuntimeError("state_manager is not set.")
+        state = await self.get_state(token)
+        for key, value in kwargs.items():
+            setattr(state, key, value)
+        try:
+            await self.state_manager.set_state(token, state)
+        finally:
+            if isinstance(self.state_manager, StateManagerRedis):
+                await self.state_manager.redis.close()
+
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[State]:
+        """Modify the state associated with the given token and send update to frontend.
+
+        Args:
+            token: The state token to modify
+
+        Yields:
+            The state instance associated with the given token
 
 
         Raises:
         Raises:
             RuntimeError: when the app hasn't started running
             RuntimeError: when the app hasn't started running
         """
         """
-        if self.app_instance is None or self.app_instance.sio is None:
+        if self.state_manager is None:
+            raise RuntimeError("state_manager is not set.")
+        if self.app_instance is None:
             raise RuntimeError("App is not running.")
             raise RuntimeError("App is not running.")
-        event_ns: EventNamespace = cast(
-            EventNamespace,
-            self.app_instance.event_namespace,
-        )
-        pending: list[Coroutine[Any, Any, Any]] = []
-        for state in self.app_instance.state_manager.states.values():
-            delta = state.get_delta()
-            if delta:
-                update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
-                state._clean()
-                # Emit the event.
-                pending.append(
-                    event_ns.emit(
-                        str(reflex.constants.SocketEvent.EVENT),
-                        update.json(),
-                        to=state.get_sid(),
-                    ),
-                )
-        responses = []
-        for request in pending:
-            responses.append(await request)
-        return responses
+        app_state_manager = self.app_instance.state_manager
+        if isinstance(self.state_manager, StateManagerRedis):
+            # Temporarily replace the app's state manager with our own, since
+            # the redis connection is on the backend_thread event loop
+            self.app_instance.state_manager = self.state_manager
+        try:
+            async with self.app_instance.modify_state(token) as state:
+                yield state
+        finally:
+            if isinstance(self.state_manager, StateManagerRedis):
+                self.app_instance.state_manager = app_state_manager
+                await self.state_manager.redis.close()
 
 
     def poll_for_content(
     def poll_for_content(
         self,
         self,
@@ -457,6 +546,9 @@ class AppHarness:
         if self.app_instance is None:
         if self.app_instance is None:
             raise RuntimeError("App is not running.")
             raise RuntimeError("App is not running.")
         state_manager = self.app_instance.state_manager
         state_manager = self.app_instance.state_manager
+        assert isinstance(
+            state_manager, StateManagerMemory
+        ), "Only works with memory state manager"
         if not self._poll_for(
         if not self._poll_for(
             target=lambda: state_manager.states,
             target=lambda: state_manager.states,
             timeout=timeout,
             timeout=timeout,
@@ -534,7 +626,6 @@ class Subdir404TCPServer(socketserver.TCPServer):
             request: the requesting socket
             request: the requesting socket
             client_address: (host, port) referring to the client’s address.
             client_address: (host, port) referring to the client’s address.
         """
         """
-        print(client_address, type(client_address))
         self.RequestHandlerClass(
         self.RequestHandlerClass(
             request,
             request,
             client_address,
             client_address,
@@ -605,6 +696,7 @@ class AppHarnessProd(AppHarness):
                 workers=reflex.utils.processes.get_num_workers(),
                 workers=reflex.utils.processes.get_num_workers(),
             ),
             ),
         )
         )
+        self.backend.shutdown = self._get_backend_shutdown_handler()
         self.backend_thread = threading.Thread(target=self.backend.run)
         self.backend_thread = threading.Thread(target=self.backend.run)
         self.backend_thread.start()
         self.backend_thread.start()
 
 

+ 8 - 0
reflex/utils/exceptions.py

@@ -5,3 +5,11 @@ class InvalidStylePropError(TypeError):
     """Custom Type Error when style props have invalid values."""
     """Custom Type Error when style props have invalid values."""
 
 
     pass
     pass
+
+
+class ImmutableStateError(AttributeError):
+    """Raised when a background task attempts to modify state outside of context."""
+
+
+class LockExpiredError(Exception):
+    """Raised when the state lock expires while an event is being processed."""

+ 5 - 3
reflex/utils/prerequisites.py

@@ -21,7 +21,7 @@ import httpx
 import typer
 import typer
 from alembic.util.exc import CommandError
 from alembic.util.exc import CommandError
 from packaging import version
 from packaging import version
-from redis import Redis
+from redis.asyncio import Redis
 
 
 from reflex import constants, model
 from reflex import constants, model
 from reflex.compiler import templates
 from reflex.compiler import templates
@@ -124,9 +124,11 @@ def get_redis() -> Redis | None:
         The redis client.
         The redis client.
     """
     """
     config = get_config()
     config = get_config()
-    if config.redis_url is None:
+    if not config.redis_url:
         return None
         return None
-    redis_url, redis_port = config.redis_url.split(":")
+    redis_url, has_port, redis_port = config.redis_url.partition(":")
+    if not has_port:
+        redis_port = 6379
     console.info(f"Using redis at {config.redis_url}")
     console.info(f"Using redis at {config.redis_url}")
     return Redis(host=redis_url, port=int(redis_port), db=0)
     return Redis(host=redis_url, port=int(redis_port), db=0)
 
 

+ 21 - 380
tests/conftest.py

@@ -2,8 +2,9 @@
 import contextlib
 import contextlib
 import os
 import os
 import platform
 import platform
+import uuid
 from pathlib import Path
 from pathlib import Path
-from typing import Dict, Generator, List, Set, Union
+from typing import Dict, Generator
 
 
 import pytest
 import pytest
 
 
@@ -11,6 +12,14 @@ import reflex as rx
 from reflex.app import App
 from reflex.app import App
 from reflex.event import EventSpec
 from reflex.event import EventSpec
 
 
+from .states import (
+    DictMutationTestState,
+    ListMutationTestState,
+    MutableTestState,
+    SubUploadState,
+    UploadState,
+)
+
 
 
 @pytest.fixture
 @pytest.fixture
 def app() -> App:
 def app() -> App:
@@ -39,60 +48,7 @@ def list_mutation_state():
     Returns:
     Returns:
         A state with list mutation features.
         A state with list mutation features.
     """
     """
-
-    class TestState(rx.State):
-        """The test state."""
-
-        # plain list
-        plain_friends = ["Tommy"]
-
-        def make_friend(self):
-            self.plain_friends.append("another-fd")
-
-        def change_first_friend(self):
-            self.plain_friends[0] = "Jenny"
-
-        def unfriend_all_friends(self):
-            self.plain_friends.clear()
-
-        def unfriend_first_friend(self):
-            del self.plain_friends[0]
-
-        def remove_last_friend(self):
-            self.plain_friends.pop()
-
-        def make_friends_with_colleagues(self):
-            colleagues = ["Peter", "Jimmy"]
-            self.plain_friends.extend(colleagues)
-
-        def remove_tommy(self):
-            self.plain_friends.remove("Tommy")
-
-        # list in dict
-        friends_in_dict = {"Tommy": ["Jenny"]}
-
-        def remove_jenny_from_tommy(self):
-            self.friends_in_dict["Tommy"].remove("Jenny")
-
-        def add_jimmy_to_tommy_friends(self):
-            self.friends_in_dict["Tommy"].append("Jimmy")
-
-        def tommy_has_no_fds(self):
-            self.friends_in_dict["Tommy"].clear()
-
-        # nested list
-        friends_in_nested_list = [["Tommy"], ["Jenny"]]
-
-        def remove_first_group(self):
-            self.friends_in_nested_list.pop(0)
-
-        def remove_first_person_from_first_group(self):
-            self.friends_in_nested_list[0].pop(0)
-
-        def add_jimmy_to_second_group(self):
-            self.friends_in_nested_list[1].append("Jimmy")
-
-    return TestState()
+    return ListMutationTestState()
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -102,85 +58,7 @@ def dict_mutation_state():
     Returns:
     Returns:
         A state with dict mutation features.
         A state with dict mutation features.
     """
     """
-
-    class TestState(rx.State):
-        """The test state."""
-
-        # plain dict
-        details = {"name": "Tommy"}
-
-        def add_age(self):
-            self.details.update({"age": 20})  # type: ignore
-
-        def change_name(self):
-            self.details["name"] = "Jenny"
-
-        def remove_last_detail(self):
-            self.details.popitem()
-
-        def clear_details(self):
-            self.details.clear()
-
-        def remove_name(self):
-            del self.details["name"]
-
-        def pop_out_age(self):
-            self.details.pop("age")
-
-        # dict in list
-        address = [{"home": "home address"}, {"work": "work address"}]
-
-        def remove_home_address(self):
-            self.address[0].pop("home")
-
-        def add_street_to_home_address(self):
-            self.address[0]["street"] = "street address"
-
-        # nested dict
-        friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
-
-        def change_friend_name(self):
-            self.friend_in_nested_dict["friend"]["name"] = "Tommy"
-
-        def remove_friend(self):
-            self.friend_in_nested_dict.pop("friend")
-
-        def add_friend_age(self):
-            self.friend_in_nested_dict["friend"]["age"] = 30
-
-    return TestState()
-
-
-class UploadState(rx.State):
-    """The base state for uploading a file."""
-
-    async def handle_upload1(self, files: List[rx.UploadFile]):
-        """Handle the upload of a file.
-
-        Args:
-            files: The uploaded files.
-        """
-        pass
-
-
-class BaseState(rx.State):
-    """The test base state."""
-
-    pass
-
-
-class SubUploadState(BaseState):
-    """The test substate."""
-
-    img: str
-
-    async def handle_upload(self, files: List[rx.UploadFile]):
-        """Handle the upload of a file.
-
-        Args:
-            files: The uploaded files.
-        """
-        pass
+    return DictMutationTestState()
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -203,187 +81,6 @@ def upload_event_spec():
     return EventSpec(handler=UploadState.handle_upload1, upload=True)  # type: ignore
     return EventSpec(handler=UploadState.handle_upload1, upload=True)  # type: ignore
 
 
 
 
-@pytest.fixture
-def upload_state(tmp_path):
-    """Create upload state.
-
-    Args:
-        tmp_path: pytest tmp_path
-
-    Returns:
-        The state
-
-    """
-
-    class FileUploadState(rx.State):
-        """The base state for uploading a file."""
-
-        img_list: List[str]
-
-        async def handle_upload2(self, files):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                self.img_list.append(file.filename)
-
-        async def multi_handle_upload(self, files: List[rx.UploadFile]):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                assert file.filename is not None
-                self.img_list.append(file.filename)
-
-    return FileUploadState
-
-
-@pytest.fixture
-def upload_sub_state(tmp_path):
-    """Create upload substate.
-
-    Args:
-        tmp_path: pytest tmp_path
-
-    Returns:
-        The state
-
-    """
-
-    class FileState(rx.State):
-        """The base state."""
-
-        pass
-
-    class FileUploadState(FileState):
-        """The substate for uploading a file."""
-
-        img_list: List[str]
-
-        async def handle_upload2(self, files):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                self.img_list.append(file.filename)
-
-        async def multi_handle_upload(self, files: List[rx.UploadFile]):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                assert file.filename is not None
-                self.img_list.append(file.filename)
-
-    return FileUploadState
-
-
-@pytest.fixture
-def upload_grand_sub_state(tmp_path):
-    """Create upload grand-state.
-
-    Args:
-        tmp_path: pytest tmp_path
-
-    Returns:
-        The state
-
-    """
-
-    class BaseFileState(rx.State):
-        """The base state."""
-
-        pass
-
-    class FileSubState(BaseFileState):
-        """The substate."""
-
-        pass
-
-    class FileUploadState(FileSubState):
-        """The grand-substate for uploading a file."""
-
-        img_list: List[str]
-
-        async def handle_upload2(self, files):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                assert file.filename is not None
-                self.img_list.append(file.filename)
-
-        async def multi_handle_upload(self, files: List[rx.UploadFile]):
-            """Handle the upload of a file.
-
-            Args:
-                files: The uploaded files.
-            """
-            for file in files:
-                upload_data = await file.read()
-                outfile = f"{tmp_path}/{file.filename}"
-
-                # Save the file.
-                with open(outfile, "wb") as file_object:
-                    file_object.write(upload_data)
-
-                # Update the img var.
-                assert file.filename is not None
-                self.img_list.append(file.filename)
-
-    return FileUploadState
-
-
 @pytest.fixture
 @pytest.fixture
 def base_config_values() -> Dict:
 def base_config_values() -> Dict:
     """Get base config values.
     """Get base config values.
@@ -418,35 +115,6 @@ def sqlite_db_config_values(base_db_config_values) -> Dict:
     return base_db_config_values
     return base_db_config_values
 
 
 
 
-class GenState(rx.State):
-    """A state with event handlers that generate multiple updates."""
-
-    value: int
-
-    def go(self, c: int):
-        """Increment the value c times and update each time.
-
-        Args:
-            c: The number of times to increment.
-
-        Yields:
-            After each increment.
-        """
-        for _ in range(c):
-            self.value += 1
-            yield
-
-
-@pytest.fixture
-def gen_state() -> GenState:
-    """A state.
-
-    Returns:
-        A test state.
-    """
-    return GenState  # type: ignore
-
-
 @pytest.fixture
 @pytest.fixture
 def router_data_headers() -> Dict[str, str]:
 def router_data_headers() -> Dict[str, str]:
     """Router data headers.
     """Router data headers.
@@ -546,44 +214,17 @@ def mutable_state():
     Returns:
     Returns:
         A state object.
         A state object.
     """
     """
+    return MutableTestState()
 
 
-    class OtherBase(rx.Base):
-        bar: str = ""
-
-    class CustomVar(rx.Base):
-        foo: str = ""
-        array: List[str] = []
-        hashmap: Dict[str, str] = {}
-        test_set: Set[str] = set()
-        custom: OtherBase = OtherBase()
-
-    class MutableTestState(rx.State):
-        """A test state."""
-
-        array: List[Union[str, List, Dict[str, str]]] = [
-            "value",
-            [1, 2, 3],
-            {"key": "value"},
-        ]
-        hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
-            "key": ["list", "of", "values"],
-            "another_key": "another_value",
-            "third_key": {"key": "value"},
-        }
-        test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
-        custom: CustomVar = CustomVar()
-        _be_custom: CustomVar = CustomVar()
-
-        def reassign_mutables(self):
-            self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
-            self.hashmap = {
-                "mod_key": ["list", "of", "values"],
-                "mod_another_key": "another_value",
-                "mod_third_key": {"key": "value"},
-            }
-            self.test_set = {1, 2, 3, 4, "five"}
 
 
-    return MutableTestState()
+@pytest.fixture(scope="function")
+def token() -> str:
+    """Create a token.
+
+    Returns:
+        A fresh/unique token string.
+    """
+    return str(uuid.uuid4())
 
 
 
 
 @pytest.fixture
 @pytest.fixture

+ 30 - 0
tests/states/__init__.py

@@ -0,0 +1,30 @@
+"""Common rx.State subclasses for use in tests."""
+import reflex as rx
+
+from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
+from .upload import (
+    ChildFileUploadState,
+    FileUploadState,
+    GrandChildFileUploadState,
+    SubUploadState,
+    UploadState,
+)
+
+
+class GenState(rx.State):
+    """A state with event handlers that generate multiple updates."""
+
+    value: int
+
+    def go(self, c: int):
+        """Increment the value c times and update each time.
+
+        Args:
+            c: The number of times to increment.
+
+        Yields:
+            After each increment.
+        """
+        for _ in range(c):
+            self.value += 1
+            yield

+ 172 - 0
tests/states/mutation.py

@@ -0,0 +1,172 @@
+"""Test states for mutable vars."""
+
+from typing import Dict, List, Set, Union
+
+import reflex as rx
+
+
+class DictMutationTestState(rx.State):
+    """A state for testing ReflexDict mutation."""
+
+    # plain dict
+    details = {"name": "Tommy"}
+
+    def add_age(self):
+        """Add an age to the dict."""
+        self.details.update({"age": 20})  # type: ignore
+
+    def change_name(self):
+        """Change the name in the dict."""
+        self.details["name"] = "Jenny"
+
+    def remove_last_detail(self):
+        """Remove the last item in the dict."""
+        self.details.popitem()
+
+    def clear_details(self):
+        """Clear the dict."""
+        self.details.clear()
+
+    def remove_name(self):
+        """Remove the name from the dict."""
+        del self.details["name"]
+
+    def pop_out_age(self):
+        """Pop out the age from the dict."""
+        self.details.pop("age")
+
+    # dict in list
+    address = [{"home": "home address"}, {"work": "work address"}]
+
+    def remove_home_address(self):
+        """Remove the home address from dict in the list."""
+        self.address[0].pop("home")
+
+    def add_street_to_home_address(self):
+        """Set street key in the dict in the list."""
+        self.address[0]["street"] = "street address"
+
+    # nested dict
+    friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
+
+    def change_friend_name(self):
+        """Change the friend's name in the nested dict."""
+        self.friend_in_nested_dict["friend"]["name"] = "Tommy"
+
+    def remove_friend(self):
+        """Remove the friend from the nested dict."""
+        self.friend_in_nested_dict.pop("friend")
+
+    def add_friend_age(self):
+        """Add an age to the friend in the nested dict."""
+        self.friend_in_nested_dict["friend"]["age"] = 30
+
+
+class ListMutationTestState(rx.State):
+    """A state for testing ReflexList mutation."""
+
+    # plain list
+    plain_friends = ["Tommy"]
+
+    def make_friend(self):
+        """Add a friend to the list."""
+        self.plain_friends.append("another-fd")
+
+    def change_first_friend(self):
+        """Change the first friend in the list."""
+        self.plain_friends[0] = "Jenny"
+
+    def unfriend_all_friends(self):
+        """Unfriend all friends in the list."""
+        self.plain_friends.clear()
+
+    def unfriend_first_friend(self):
+        """Unfriend the first friend in the list."""
+        del self.plain_friends[0]
+
+    def remove_last_friend(self):
+        """Remove the last friend in the list."""
+        self.plain_friends.pop()
+
+    def make_friends_with_colleagues(self):
+        """Add list of friends to the list."""
+        colleagues = ["Peter", "Jimmy"]
+        self.plain_friends.extend(colleagues)
+
+    def remove_tommy(self):
+        """Remove Tommy from the list."""
+        self.plain_friends.remove("Tommy")
+
+    # list in dict
+    friends_in_dict = {"Tommy": ["Jenny"]}
+
+    def remove_jenny_from_tommy(self):
+        """Remove Jenny from Tommy's friends list."""
+        self.friends_in_dict["Tommy"].remove("Jenny")
+
+    def add_jimmy_to_tommy_friends(self):
+        """Add Jimmy to Tommy's friends list."""
+        self.friends_in_dict["Tommy"].append("Jimmy")
+
+    def tommy_has_no_fds(self):
+        """Clear Tommy's friends list."""
+        self.friends_in_dict["Tommy"].clear()
+
+    # nested list
+    friends_in_nested_list = [["Tommy"], ["Jenny"]]
+
+    def remove_first_group(self):
+        """Remove the first group of friends from the nested list."""
+        self.friends_in_nested_list.pop(0)
+
+    def remove_first_person_from_first_group(self):
+        """Remove the first person from the first group of friends in the nested list."""
+        self.friends_in_nested_list[0].pop(0)
+
+    def add_jimmy_to_second_group(self):
+        """Add Jimmy to the second group of friends in the nested list."""
+        self.friends_in_nested_list[1].append("Jimmy")
+
+
+class OtherBase(rx.Base):
+    """A Base model with a str field."""
+
+    bar: str = ""
+
+
+class CustomVar(rx.Base):
+    """A Base model with multiple fields."""
+
+    foo: str = ""
+    array: List[str] = []
+    hashmap: Dict[str, str] = {}
+    test_set: Set[str] = set()
+    custom: OtherBase = OtherBase()
+
+
+class MutableTestState(rx.State):
+    """A test state."""
+
+    array: List[Union[str, List, Dict[str, str]]] = [
+        "value",
+        [1, 2, 3],
+        {"key": "value"},
+    ]
+    hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
+        "key": ["list", "of", "values"],
+        "another_key": "another_value",
+        "third_key": {"key": "value"},
+    }
+    test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
+    custom: CustomVar = CustomVar()
+    _be_custom: CustomVar = CustomVar()
+
+    def reassign_mutables(self):
+        """Assign mutable fields to different values."""
+        self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
+        self.hashmap = {
+            "mod_key": ["list", "of", "values"],
+            "mod_another_key": "another_value",
+            "mod_third_key": {"key": "value"},
+        }
+        self.test_set = {1, 2, 3, 4, "five"}

+ 175 - 0
tests/states/upload.py

@@ -0,0 +1,175 @@
+"""Test states for upload-related tests."""
+from pathlib import Path
+from typing import ClassVar, List
+
+import reflex as rx
+
+
+class UploadState(rx.State):
+    """The base state for uploading a file."""
+
+    async def handle_upload1(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass
+
+
+class BaseState(rx.State):
+    """The test base state."""
+
+    pass
+
+
+class SubUploadState(BaseState):
+    """The test substate."""
+
+    img: str
+
+    async def handle_upload(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass
+
+
+class FileUploadState(rx.State):
+    """The base state for uploading a file."""
+
+    img_list: List[str]
+    _tmp_path: ClassVar[Path]
+
+    async def handle_upload2(self, files):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            self.img_list.append(file.filename)
+
+    async def multi_handle_upload(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            assert file.filename is not None
+            self.img_list.append(file.filename)
+
+
+class FileStateBase1(rx.State):
+    """The base state for a child FileUploadState."""
+
+    pass
+
+
+class ChildFileUploadState(FileStateBase1):
+    """The child state for uploading a file."""
+
+    img_list: List[str]
+    _tmp_path: ClassVar[Path]
+
+    async def handle_upload2(self, files):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            self.img_list.append(file.filename)
+
+    async def multi_handle_upload(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            assert file.filename is not None
+            self.img_list.append(file.filename)
+
+
+class FileStateBase2(FileStateBase1):
+    """The parent state for a grandchild FileUploadState."""
+
+    pass
+
+
+class GrandChildFileUploadState(FileStateBase2):
+    """The child state for uploading a file."""
+
+    img_list: List[str]
+    _tmp_path: ClassVar[Path]
+
+    async def handle_upload2(self, files):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            self.img_list.append(file.filename)
+
+    async def multi_handle_upload(self, files: List[rx.UploadFile]):
+        """Handle the upload of a file.
+
+        Args:
+            files: The uploaded files.
+        """
+        for file in files:
+            upload_data = await file.read()
+            outfile = f"{self._tmp_path}/{file.filename}"
+
+            # Save the file.
+            with open(outfile, "wb") as file_object:
+                file_object.write(upload_data)
+
+            # Update the img var.
+            assert file.filename is not None
+            self.img_list.append(file.filename)

+ 237 - 121
tests/test_app.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import io
 import io
 import os.path
 import os.path
 import sys
 import sys
+import uuid
 from typing import List, Tuple, Type
 from typing import List, Tuple, Type
 
 
 if sys.version_info.major >= 3 and sys.version_info.minor > 7:
 if sys.version_info.major >= 3 and sys.version_info.minor > 7:
@@ -30,11 +31,18 @@ from reflex.components import Box, Component, Cond, Fragment, Text
 from reflex.event import Event, get_hydrate_event
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware import HydrateMiddleware
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.model import Model
-from reflex.state import State, StateUpdate
+from reflex.state import State, StateManagerRedis, StateUpdate
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format
 from reflex.utils import format
 from reflex.vars import ComputedVar
 from reflex.vars import ComputedVar
 
 
+from .states import (
+    ChildFileUploadState,
+    FileUploadState,
+    GenState,
+    GrandChildFileUploadState,
+)
+
 
 
 @pytest.fixture
 @pytest.fixture
 def index_page():
 def index_page():
@@ -64,6 +72,12 @@ def about_page():
     return about
     return about
 
 
 
 
+class ATestState(State):
+    """A simple state for testing."""
+
+    var: int
+
+
 @pytest.fixture()
 @pytest.fixture()
 def test_state() -> Type[State]:
 def test_state() -> Type[State]:
     """A default state.
     """A default state.
@@ -71,11 +85,7 @@ def test_state() -> Type[State]:
     Returns:
     Returns:
         A default state.
         A default state.
     """
     """
-
-    class TestState(State):
-        var: int
-
-    return TestState
+    return ATestState
 
 
 
 
 @pytest.fixture()
 @pytest.fixture()
@@ -313,23 +323,28 @@ def test_initialize_admin_dashboard_with_view_overrides(test_model):
     assert app.admin_dash.view_overrides[test_model] == TestModelView
     assert app.admin_dash.view_overrides[test_model] == TestModelView
 
 
 
 
-def test_initialize_with_state(test_state):
+@pytest.mark.asyncio
+async def test_initialize_with_state(test_state: Type[ATestState], token: str):
     """Test setting the state of an app.
     """Test setting the state of an app.
 
 
     Args:
     Args:
         test_state: The default state.
         test_state: The default state.
+        token: a Token.
     """
     """
     app = App(state=test_state)
     app = App(state=test_state)
     assert app.state == test_state
     assert app.state == test_state
 
 
     # Get a state for a given token.
     # Get a state for a given token.
-    token = "token"
-    state = app.state_manager.get_state(token)
+    state = await app.state_manager.get_state(token)
     assert isinstance(state, test_state)
     assert isinstance(state, test_state)
     assert state.var == 0  # type: ignore
     assert state.var == 0  # type: ignore
 
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 
-def test_set_and_get_state(test_state):
+@pytest.mark.asyncio
+async def test_set_and_get_state(test_state):
     """Test setting and getting the state of an app with different tokens.
     """Test setting and getting the state of an app with different tokens.
 
 
     Args:
     Args:
@@ -338,47 +353,51 @@ def test_set_and_get_state(test_state):
     app = App(state=test_state)
     app = App(state=test_state)
 
 
     # Create two tokens.
     # Create two tokens.
-    token1 = "token1"
-    token2 = "token2"
+    token1 = str(uuid.uuid4())
+    token2 = str(uuid.uuid4())
 
 
     # Get the default state for each token.
     # Get the default state for each token.
-    state1 = app.state_manager.get_state(token1)
-    state2 = app.state_manager.get_state(token2)
+    state1 = await app.state_manager.get_state(token1)
+    state2 = await app.state_manager.get_state(token2)
     assert state1.var == 0  # type: ignore
     assert state1.var == 0  # type: ignore
     assert state2.var == 0  # type: ignore
     assert state2.var == 0  # type: ignore
 
 
     # Set the vars to different values.
     # Set the vars to different values.
     state1.var = 1
     state1.var = 1
     state2.var = 2
     state2.var = 2
-    app.state_manager.set_state(token1, state1)
-    app.state_manager.set_state(token2, state2)
+    await app.state_manager.set_state(token1, state1)
+    await app.state_manager.set_state(token2, state2)
 
 
     # Get the states again and check the values.
     # Get the states again and check the values.
-    state1 = app.state_manager.get_state(token1)
-    state2 = app.state_manager.get_state(token2)
+    state1 = await app.state_manager.get_state(token1)
+    state2 = await app.state_manager.get_state(token2)
     assert state1.var == 1  # type: ignore
     assert state1.var == 1  # type: ignore
     assert state2.var == 2  # type: ignore
     assert state2.var == 2  # type: ignore
 
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_dynamic_var_event(test_state):
+async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
     """Test that the default handler of a dynamic generated var
     """Test that the default handler of a dynamic generated var
     works as expected.
     works as expected.
 
 
     Args:
     Args:
         test_state: State Fixture.
         test_state: State Fixture.
+        token: a Token.
     """
     """
-    test_state = test_state()
-    test_state.add_var("int_val", int, 0)
-    result = await test_state._process(
+    state = test_state()  # type: ignore
+    state.add_var("int_val", int, 0)
+    result = await state._process(
         Event(
         Event(
-            token="fake-token",
-            name="test_state.set_int_val",
+            token=token,
+            name=f"{test_state.get_name()}.set_int_val",
             router_data={"pathname": "/", "query": {}},
             router_data={"pathname": "/", "query": {}},
             payload={"value": 50},
             payload={"value": 50},
         )
         )
     ).__anext__()
     ).__anext__()
-    assert result.delta == {"test_state": {"int_val": 50}}
+    assert result.delta == {test_state.get_name(): {"int_val": 50}}
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
@@ -388,12 +407,20 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.make_friend",
-                    {"test_state": {"plain_friends": ["Tommy", "another-fd"]}},
+                    "list_mutation_test_state.make_friend",
+                    {
+                        "list_mutation_test_state": {
+                            "plain_friends": ["Tommy", "another-fd"]
+                        }
+                    },
                 ),
                 ),
                 (
                 (
-                    "test_state.change_first_friend",
-                    {"test_state": {"plain_friends": ["Jenny", "another-fd"]}},
+                    "list_mutation_test_state.change_first_friend",
+                    {
+                        "list_mutation_test_state": {
+                            "plain_friends": ["Jenny", "another-fd"]
+                        }
+                    },
                 ),
                 ),
             ],
             ],
             id="append then __setitem__",
             id="append then __setitem__",
@@ -401,12 +428,12 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.unfriend_first_friend",
-                    {"test_state": {"plain_friends": []}},
+                    "list_mutation_test_state.unfriend_first_friend",
+                    {"list_mutation_test_state": {"plain_friends": []}},
                 ),
                 ),
                 (
                 (
-                    "test_state.make_friend",
-                    {"test_state": {"plain_friends": ["another-fd"]}},
+                    "list_mutation_test_state.make_friend",
+                    {"list_mutation_test_state": {"plain_friends": ["another-fd"]}},
                 ),
                 ),
             ],
             ],
             id="delitem then append",
             id="delitem then append",
@@ -414,20 +441,24 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.make_friends_with_colleagues",
-                    {"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}},
+                    "list_mutation_test_state.make_friends_with_colleagues",
+                    {
+                        "list_mutation_test_state": {
+                            "plain_friends": ["Tommy", "Peter", "Jimmy"]
+                        }
+                    },
                 ),
                 ),
                 (
                 (
-                    "test_state.remove_tommy",
-                    {"test_state": {"plain_friends": ["Peter", "Jimmy"]}},
+                    "list_mutation_test_state.remove_tommy",
+                    {"list_mutation_test_state": {"plain_friends": ["Peter", "Jimmy"]}},
                 ),
                 ),
                 (
                 (
-                    "test_state.remove_last_friend",
-                    {"test_state": {"plain_friends": ["Peter"]}},
+                    "list_mutation_test_state.remove_last_friend",
+                    {"list_mutation_test_state": {"plain_friends": ["Peter"]}},
                 ),
                 ),
                 (
                 (
-                    "test_state.unfriend_all_friends",
-                    {"test_state": {"plain_friends": []}},
+                    "list_mutation_test_state.unfriend_all_friends",
+                    {"list_mutation_test_state": {"plain_friends": []}},
                 ),
                 ),
             ],
             ],
             id="extend, remove, pop, clear",
             id="extend, remove, pop, clear",
@@ -435,24 +466,28 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.add_jimmy_to_second_group",
+                    "list_mutation_test_state.add_jimmy_to_second_group",
                     {
                     {
-                        "test_state": {
+                        "list_mutation_test_state": {
                             "friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
                             "friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
                         }
                         }
                     },
                     },
                 ),
                 ),
                 (
                 (
-                    "test_state.remove_first_person_from_first_group",
+                    "list_mutation_test_state.remove_first_person_from_first_group",
                     {
                     {
-                        "test_state": {
+                        "list_mutation_test_state": {
                             "friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
                             "friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
                         }
                         }
                     },
                     },
                 ),
                 ),
                 (
                 (
-                    "test_state.remove_first_group",
-                    {"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}},
+                    "list_mutation_test_state.remove_first_group",
+                    {
+                        "list_mutation_test_state": {
+                            "friends_in_nested_list": [["Jenny", "Jimmy"]]
+                        }
+                    },
                 ),
                 ),
             ],
             ],
             id="nested list",
             id="nested list",
@@ -460,16 +495,24 @@ async def test_dynamic_var_event(test_state):
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.add_jimmy_to_tommy_friends",
-                    {"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}},
+                    "list_mutation_test_state.add_jimmy_to_tommy_friends",
+                    {
+                        "list_mutation_test_state": {
+                            "friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}
+                        }
+                    },
                 ),
                 ),
                 (
                 (
-                    "test_state.remove_jenny_from_tommy",
-                    {"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}},
+                    "list_mutation_test_state.remove_jenny_from_tommy",
+                    {
+                        "list_mutation_test_state": {
+                            "friends_in_dict": {"Tommy": ["Jimmy"]}
+                        }
+                    },
                 ),
                 ),
                 (
                 (
-                    "test_state.tommy_has_no_fds",
-                    {"test_state": {"friends_in_dict": {"Tommy": []}}},
+                    "list_mutation_test_state.tommy_has_no_fds",
+                    {"list_mutation_test_state": {"friends_in_dict": {"Tommy": []}}},
                 ),
                 ),
             ],
             ],
             id="list in dict",
             id="list in dict",
@@ -477,7 +520,9 @@ async def test_dynamic_var_event(test_state):
     ],
     ],
 )
 )
 async def test_list_mutation_detection__plain_list(
 async def test_list_mutation_detection__plain_list(
-    event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State
+    event_tuples: List[Tuple[str, List[str]]],
+    list_mutation_state: State,
+    token: str,
 ):
 ):
     """Test list mutation detection
     """Test list mutation detection
     when reassignment is not explicitly included in the logic.
     when reassignment is not explicitly included in the logic.
@@ -485,11 +530,12 @@ async def test_list_mutation_detection__plain_list(
     Args:
     Args:
         event_tuples: From parametrization.
         event_tuples: From parametrization.
         list_mutation_state: A state with list mutation features.
         list_mutation_state: A state with list mutation features.
+        token: a Token.
     """
     """
     for event_name, expected_delta in event_tuples:
     for event_name, expected_delta in event_tuples:
         result = await list_mutation_state._process(
         result = await list_mutation_state._process(
             Event(
             Event(
-                token="fake-token",
+                token=token,
                 name=event_name,
                 name=event_name,
                 router_data={"pathname": "/", "query": {}},
                 router_data={"pathname": "/", "query": {}},
                 payload={},
                 payload={},
@@ -506,16 +552,24 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.add_age",
-                    {"test_state": {"details": {"name": "Tommy", "age": 20}}},
+                    "dict_mutation_test_state.add_age",
+                    {
+                        "dict_mutation_test_state": {
+                            "details": {"name": "Tommy", "age": 20}
+                        }
+                    },
                 ),
                 ),
                 (
                 (
-                    "test_state.change_name",
-                    {"test_state": {"details": {"name": "Jenny", "age": 20}}},
+                    "dict_mutation_test_state.change_name",
+                    {
+                        "dict_mutation_test_state": {
+                            "details": {"name": "Jenny", "age": 20}
+                        }
+                    },
                 ),
                 ),
                 (
                 (
-                    "test_state.remove_last_detail",
-                    {"test_state": {"details": {"name": "Jenny"}}},
+                    "dict_mutation_test_state.remove_last_detail",
+                    {"dict_mutation_test_state": {"details": {"name": "Jenny"}}},
                 ),
                 ),
             ],
             ],
             id="update then __setitem__",
             id="update then __setitem__",
@@ -523,12 +577,12 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.clear_details",
-                    {"test_state": {"details": {}}},
+                    "dict_mutation_test_state.clear_details",
+                    {"dict_mutation_test_state": {"details": {}}},
                 ),
                 ),
                 (
                 (
-                    "test_state.add_age",
-                    {"test_state": {"details": {"age": 20}}},
+                    "dict_mutation_test_state.add_age",
+                    {"dict_mutation_test_state": {"details": {"age": 20}}},
                 ),
                 ),
             ],
             ],
             id="delitem then update",
             id="delitem then update",
@@ -536,16 +590,20 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.add_age",
-                    {"test_state": {"details": {"name": "Tommy", "age": 20}}},
+                    "dict_mutation_test_state.add_age",
+                    {
+                        "dict_mutation_test_state": {
+                            "details": {"name": "Tommy", "age": 20}
+                        }
+                    },
                 ),
                 ),
                 (
                 (
-                    "test_state.remove_name",
-                    {"test_state": {"details": {"age": 20}}},
+                    "dict_mutation_test_state.remove_name",
+                    {"dict_mutation_test_state": {"details": {"age": 20}}},
                 ),
                 ),
                 (
                 (
-                    "test_state.pop_out_age",
-                    {"test_state": {"details": {}}},
+                    "dict_mutation_test_state.pop_out_age",
+                    {"dict_mutation_test_state": {"details": {}}},
                 ),
                 ),
             ],
             ],
             id="add, remove, pop",
             id="add, remove, pop",
@@ -553,13 +611,17 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.remove_home_address",
-                    {"test_state": {"address": [{}, {"work": "work address"}]}},
+                    "dict_mutation_test_state.remove_home_address",
+                    {
+                        "dict_mutation_test_state": {
+                            "address": [{}, {"work": "work address"}]
+                        }
+                    },
                 ),
                 ),
                 (
                 (
-                    "test_state.add_street_to_home_address",
+                    "dict_mutation_test_state.add_street_to_home_address",
                     {
                     {
-                        "test_state": {
+                        "dict_mutation_test_state": {
                             "address": [
                             "address": [
                                 {"street": "street address"},
                                 {"street": "street address"},
                                 {"work": "work address"},
                                 {"work": "work address"},
@@ -573,9 +635,9 @@ async def test_list_mutation_detection__plain_list(
         pytest.param(
         pytest.param(
             [
             [
                 (
                 (
-                    "test_state.change_friend_name",
+                    "dict_mutation_test_state.change_friend_name",
                     {
                     {
-                        "test_state": {
+                        "dict_mutation_test_state": {
                             "friend_in_nested_dict": {
                             "friend_in_nested_dict": {
                                 "name": "Nikhil",
                                 "name": "Nikhil",
                                 "friend": {"name": "Tommy"},
                                 "friend": {"name": "Tommy"},
@@ -584,9 +646,9 @@ async def test_list_mutation_detection__plain_list(
                     },
                     },
                 ),
                 ),
                 (
                 (
-                    "test_state.add_friend_age",
+                    "dict_mutation_test_state.add_friend_age",
                     {
                     {
-                        "test_state": {
+                        "dict_mutation_test_state": {
                             "friend_in_nested_dict": {
                             "friend_in_nested_dict": {
                                 "name": "Nikhil",
                                 "name": "Nikhil",
                                 "friend": {"name": "Tommy", "age": 30},
                                 "friend": {"name": "Tommy", "age": 30},
@@ -595,8 +657,12 @@ async def test_list_mutation_detection__plain_list(
                     },
                     },
                 ),
                 ),
                 (
                 (
-                    "test_state.remove_friend",
-                    {"test_state": {"friend_in_nested_dict": {"name": "Nikhil"}}},
+                    "dict_mutation_test_state.remove_friend",
+                    {
+                        "dict_mutation_test_state": {
+                            "friend_in_nested_dict": {"name": "Nikhil"}
+                        }
+                    },
                 ),
                 ),
             ],
             ],
             id="nested dict",
             id="nested dict",
@@ -604,7 +670,9 @@ async def test_list_mutation_detection__plain_list(
     ],
     ],
 )
 )
 async def test_dict_mutation_detection__plain_list(
 async def test_dict_mutation_detection__plain_list(
-    event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State
+    event_tuples: List[Tuple[str, List[str]]],
+    dict_mutation_state: State,
+    token: str,
 ):
 ):
     """Test dict mutation detection
     """Test dict mutation detection
     when reassignment is not explicitly included in the logic.
     when reassignment is not explicitly included in the logic.
@@ -612,11 +680,12 @@ async def test_dict_mutation_detection__plain_list(
     Args:
     Args:
         event_tuples: From parametrization.
         event_tuples: From parametrization.
         dict_mutation_state: A state with dict mutation features.
         dict_mutation_state: A state with dict mutation features.
+        token: a Token.
     """
     """
     for event_name, expected_delta in event_tuples:
     for event_name, expected_delta in event_tuples:
         result = await dict_mutation_state._process(
         result = await dict_mutation_state._process(
             Event(
             Event(
-                token="fake-token",
+                token=token,
                 name=event_name,
                 name=event_name,
                 router_data={"pathname": "/", "query": {}},
                 router_data={"pathname": "/", "query": {}},
                 payload={},
                 payload={},
@@ -628,41 +697,43 @@ async def test_dict_mutation_detection__plain_list(
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
-    "fixture, delta",
+    ("state", "delta"),
     [
     [
         (
         (
-            "upload_state",
+            FileUploadState,
             {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
             {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
         ),
         ),
         (
         (
-            "upload_sub_state",
+            ChildFileUploadState,
             {
             {
-                "file_state.file_upload_state": {
+                "file_state_base1.child_file_upload_state": {
                     "img_list": ["image1.jpg", "image2.jpg"]
                     "img_list": ["image1.jpg", "image2.jpg"]
                 }
                 }
             },
             },
         ),
         ),
         (
         (
-            "upload_grand_sub_state",
+            GrandChildFileUploadState,
             {
             {
-                "base_file_state.file_sub_state.file_upload_state": {
+                "file_state_base1.file_state_base2.grand_child_file_upload_state": {
                     "img_list": ["image1.jpg", "image2.jpg"]
                     "img_list": ["image1.jpg", "image2.jpg"]
                 }
                 }
             },
             },
         ),
         ),
     ],
     ],
 )
 )
-async def test_upload_file(fixture, request, delta):
+async def test_upload_file(tmp_path, state, delta, token: str):
     """Test that file upload works correctly.
     """Test that file upload works correctly.
 
 
     Args:
     Args:
-        fixture: The state.
-        request: Fixture request.
+        tmp_path: Temporary path.
+        state: The state class.
         delta: Expected delta
         delta: Expected delta
+        token: a Token.
     """
     """
-    app = App(state=request.getfixturevalue(fixture))
+    state._tmp_path = tmp_path
+    app = App(state=state)
     app.event_namespace.emit = AsyncMock()  # type: ignore
     app.event_namespace.emit = AsyncMock()  # type: ignore
-    current_state = app.state_manager.get_state("token")
+    current_state = await app.state_manager.get_state(token)
     data = b"This is binary data"
     data = b"This is binary data"
 
 
     # Create a binary IO object and write data to it
     # Create a binary IO object and write data to it
@@ -670,11 +741,11 @@ async def test_upload_file(fixture, request, delta):
     bio.write(data)
     bio.write(data)
 
 
     file1 = UploadFile(
     file1 = UploadFile(
-        filename="token:file_upload_state.multi_handle_upload:True:image1.jpg",
+        filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image1.jpg",
         file=bio,
         file=bio,
     )
     )
     file2 = UploadFile(
     file2 = UploadFile(
-        filename="token:file_upload_state.multi_handle_upload:True:image2.jpg",
+        filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image2.jpg",
         file=bio,
         file=bio,
     )
     )
     upload_fn = upload(app)
     upload_fn = upload(app)
@@ -684,22 +755,27 @@ async def test_upload_file(fixture, request, delta):
     app.event_namespace.emit.assert_called_with(  # type: ignore
     app.event_namespace.emit.assert_called_with(  # type: ignore
         "event", state_update.json(), to=current_state.get_sid()
         "event", state_update.json(), to=current_state.get_sid()
     )
     )
-    assert app.state_manager.get_state("token").dict()["img_list"] == [
+    assert (await app.state_manager.get_state(token)).dict()["img_list"] == [
         "image1.jpg",
         "image1.jpg",
         "image2.jpg",
         "image2.jpg",
     ]
     ]
 
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
-    "fixture", ["upload_state", "upload_sub_state", "upload_grand_sub_state"]
+    "state",
+    [FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
 )
 )
-async def test_upload_file_without_annotation(fixture, request):
+async def test_upload_file_without_annotation(state, tmp_path, token):
     """Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].
     """Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].
 
 
     Args:
     Args:
-        fixture: The state.
-        request: Fixture request.
+        state: The state class.
+        tmp_path: Temporary path.
+        token: a Token.
     """
     """
     data = b"This is binary data"
     data = b"This is binary data"
 
 
@@ -707,14 +783,15 @@ async def test_upload_file_without_annotation(fixture, request):
     bio = io.BytesIO()
     bio = io.BytesIO()
     bio.write(data)
     bio.write(data)
 
 
-    app = App(state=request.getfixturevalue(fixture))
+    state._tmp_path = tmp_path
+    app = App(state=state)
 
 
     file1 = UploadFile(
     file1 = UploadFile(
-        filename="token:file_upload_state.handle_upload2:True:image1.jpg",
+        filename=f"{token}:{state.get_name()}.handle_upload2:True:image1.jpg",
         file=bio,
         file=bio,
     )
     )
     file2 = UploadFile(
     file2 = UploadFile(
-        filename="token:file_upload_state.handle_upload2:True:image2.jpg",
+        filename=f"{token}:{state.get_name()}.handle_upload2:True:image2.jpg",
         file=bio,
         file=bio,
     )
     )
     fn = upload(app)
     fn = upload(app)
@@ -722,9 +799,12 @@ async def test_upload_file_without_annotation(fixture, request):
         await fn([file1, file2])
         await fn([file1, file2])
     assert (
     assert (
         err.value.args[0]
         err.value.args[0]
-        == "`file_upload_state.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
+        == f"`{state.get_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
     )
     )
 
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 
 class DynamicState(State):
 class DynamicState(State):
     """State class for testing dynamic route var.
     """State class for testing dynamic route var.
@@ -768,6 +848,7 @@ class DynamicState(State):
 async def test_dynamic_route_var_route_change_completed_on_load(
 async def test_dynamic_route_var_route_change_completed_on_load(
     index_page,
     index_page,
     windows_platform: bool,
     windows_platform: bool,
+    token: str,
 ):
 ):
     """Create app with dynamic route var, and simulate navigation.
     """Create app with dynamic route var, and simulate navigation.
 
 
@@ -777,6 +858,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     Args:
     Args:
         index_page: The index page.
         index_page: The index page.
         windows_platform: Whether the system is windows.
         windows_platform: Whether the system is windows.
+        token: a Token.
     """
     """
     arg_name = "dynamic"
     arg_name = "dynamic"
     route = f"/test/[{arg_name}]"
     route = f"/test/[{arg_name}]"
@@ -792,10 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     }
     }
     assert constants.ROUTER_DATA in app.state().computed_var_dependencies
     assert constants.ROUTER_DATA in app.state().computed_var_dependencies
 
 
-    token = "mock_token"
     sid = "mock_sid"
     sid = "mock_sid"
     client_ip = "127.0.0.1"
     client_ip = "127.0.0.1"
-    state = app.state_manager.get_state(token)
+    state = await app.state_manager.get_state(token)
     assert state.dynamic == ""
     assert state.dynamic == ""
     exp_vals = ["foo", "foobar", "baz"]
     exp_vals = ["foo", "foobar", "baz"]
 
 
@@ -817,6 +898,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             **kwargs,
             **kwargs,
         )
         )
 
 
+    prev_exp_val = ""
     for exp_index, exp_val in enumerate(exp_vals):
     for exp_index, exp_val in enumerate(exp_vals):
         hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
         hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
         exp_router_data = {
         exp_router_data = {
@@ -826,13 +908,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             "token": token,
             "token": token,
             **hydrate_event.router_data,
             **hydrate_event.router_data,
         }
         }
-        update = await process(
+        process_coro = process(
             app,
             app,
             event=hydrate_event,
             event=hydrate_event,
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
-        ).__anext__()  # type: ignore
+        )
+        update = await process_coro.__anext__()  # type: ignore
 
 
         # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
         # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
         assert update == StateUpdate(
@@ -860,14 +943,27 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                 ),
                 ),
             ],
             ],
         )
         )
+        if isinstance(app.state_manager, StateManagerRedis):
+            # When redis is used, the state is not updated until the processing is complete
+            state = await app.state_manager.get_state(token)
+            assert state.dynamic == prev_exp_val
+
+        # complete the processing
+        with pytest.raises(StopAsyncIteration):
+            await process_coro.__anext__()  # type: ignore
+
+        # check that router data was written to the state_manager store
+        state = await app.state_manager.get_state(token)
         assert state.dynamic == exp_val
         assert state.dynamic == exp_val
-        on_load_update = await process(
+
+        process_coro = process(
             app,
             app,
             event=_dynamic_state_event(name="on_load", val=exp_val),
             event=_dynamic_state_event(name="on_load", val=exp_val),
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
-        ).__anext__()  # type: ignore
+        )
+        on_load_update = await process_coro.__anext__()  # type: ignore
         assert on_load_update == StateUpdate(
         assert on_load_update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
@@ -879,7 +975,10 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             },
             },
             events=[],
             events=[],
         )
         )
-        on_set_is_hydrated_update = await process(
+        # complete the processing
+        with pytest.raises(StopAsyncIteration):
+            await process_coro.__anext__()  # type: ignore
+        process_coro = process(
             app,
             app,
             event=_dynamic_state_event(
             event=_dynamic_state_event(
                 name="set_is_hydrated", payload={"value": True}, val=exp_val
                 name="set_is_hydrated", payload={"value": True}, val=exp_val
@@ -887,7 +986,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
-        ).__anext__()  # type: ignore
+        )
+        on_set_is_hydrated_update = await process_coro.__anext__()  # type: ignore
         assert on_set_is_hydrated_update == StateUpdate(
         assert on_set_is_hydrated_update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
@@ -899,15 +999,19 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             },
             },
             events=[],
             events=[],
         )
         )
+        # complete the processing
+        with pytest.raises(StopAsyncIteration):
+            await process_coro.__anext__()  # type: ignore
 
 
         # a simple state update event should NOT trigger on_load or route var side effects
         # a simple state update event should NOT trigger on_load or route var side effects
-        update = await process(
+        process_coro = process(
             app,
             app,
             event=_dynamic_state_event(name="on_counter", val=exp_val),
             event=_dynamic_state_event(name="on_counter", val=exp_val),
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
-        ).__anext__()  # type: ignore
+        )
+        update = await process_coro.__anext__()  # type: ignore
         assert update == StateUpdate(
         assert update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
@@ -919,42 +1023,54 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             },
             },
             events=[],
             events=[],
         )
         )
+        # complete the processing
+        with pytest.raises(StopAsyncIteration):
+            await process_coro.__anext__()  # type: ignore
+
+        prev_exp_val = exp_val
+    state = await app.state_manager.get_state(token)
     assert state.loaded == len(exp_vals)
     assert state.loaded == len(exp_vals)
     assert state.counter == len(exp_vals)
     assert state.counter == len(exp_vals)
     # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
     # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
     # assert state.side_effect_counter == len(exp_vals)
     # assert state.side_effect_counter == len(exp_vals)
 
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_process_events(gen_state, mocker):
+async def test_process_events(mocker, token: str):
     """Test that an event is processed properly and that it is postprocessed
     """Test that an event is processed properly and that it is postprocessed
     n+1 times. Also check that the processing flag of the last stateupdate is set to
     n+1 times. Also check that the processing flag of the last stateupdate is set to
     False.
     False.
 
 
     Args:
     Args:
-        gen_state: The state.
         mocker: mocker object.
         mocker: mocker object.
+        token: a Token.
     """
     """
     router_data = {
     router_data = {
         "pathname": "/",
         "pathname": "/",
         "query": {},
         "query": {},
-        "token": "mock_token",
+        "token": token,
         "sid": "mock_sid",
         "sid": "mock_sid",
         "headers": {},
         "headers": {},
         "ip": "127.0.0.1",
         "ip": "127.0.0.1",
     }
     }
-    app = App(state=gen_state)
+    app = App(state=GenState)
     mocker.patch.object(app, "postprocess", AsyncMock())
     mocker.patch.object(app, "postprocess", AsyncMock())
     event = Event(
     event = Event(
-        token="token", name="gen_state.go", payload={"c": 5}, router_data=router_data
+        token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
     )
     )
 
 
     async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):  # type: ignore
     async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):  # type: ignore
         pass
         pass
 
 
-    assert app.state_manager.get_state("token").value == 5
+    assert (await app.state_manager.get_state(token)).value == 5
     assert app.postprocess.call_count == 6
     assert app.postprocess.call_count == 6
 
 
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     ("state", "overlay_component", "exp_page_child"),
     ("state", "overlay_component", "exp_page_child"),

+ 417 - 11
tests/test_state.py

@@ -1,22 +1,42 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
+import asyncio
 import copy
 import copy
 import datetime
 import datetime
 import functools
 import functools
+import json
+import os
 import sys
 import sys
-from typing import Dict, List
+from typing import Dict, Generator, List
+from unittest.mock import AsyncMock, Mock
 
 
 import pytest
 import pytest
 from plotly.graph_objects import Figure
 from plotly.graph_objects import Figure
 
 
 import reflex as rx
 import reflex as rx
 from reflex.base import Base
 from reflex.base import Base
-from reflex.constants import IS_HYDRATED, RouteVar
+from reflex.constants import APP_VAR, IS_HYDRATED, RouteVar, SocketEvent
 from reflex.event import Event, EventHandler
 from reflex.event import Event, EventHandler
-from reflex.state import MutableProxy, State
-from reflex.utils import format
+from reflex.state import (
+    ImmutableStateError,
+    LockExpiredError,
+    MutableProxy,
+    State,
+    StateManager,
+    StateManagerMemory,
+    StateManagerRedis,
+    StateProxy,
+    StateUpdate,
+)
+from reflex.utils import format, prerequisites
 from reflex.vars import BaseVar, ComputedVar
 from reflex.vars import BaseVar, ComputedVar
 
 
+from .states import GenState
+
+CI = bool(os.environ.get("CI", False))
+LOCK_EXPIRATION = 2000 if CI else 100
+LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
+
 
 
 class Object(Base):
 class Object(Base):
     """A test object fixture."""
     """A test object fixture."""
@@ -704,13 +724,9 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_process_event_generator(gen_state):
-    """Test event handlers that generate multiple updates.
-
-    Args:
-        gen_state: A state.
-    """
-    gen_state = gen_state()
+async def test_process_event_generator():
+    """Test event handlers that generate multiple updates."""
+    gen_state = GenState()  # type: ignore
     event = Event(
     event = Event(
         token="t",
         token="t",
         name="go",
         name="go",
@@ -1402,6 +1418,396 @@ def test_state_with_invalid_yield():
     )
     )
 
 
 
 
+@pytest.fixture(scope="function", params=["in_process", "redis"])
+def state_manager(request) -> Generator[StateManager, None, None]:
+    """Instance of state manager parametrized for redis and in-process.
+
+    Args:
+        request: pytest request object.
+
+    Yields:
+        A state manager instance
+    """
+    state_manager = StateManager.create(state=TestState)
+    if request.param == "redis":
+        if not isinstance(state_manager, StateManagerRedis):
+            pytest.skip("Test requires redis")
+    else:
+        # explicitly NOT using redis
+        state_manager = StateManagerMemory(state=TestState)
+        assert not state_manager._states_locks
+
+    yield state_manager
+
+    if isinstance(state_manager, StateManagerRedis):
+        asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
+
+
+@pytest.mark.asyncio
+async def test_state_manager_modify_state(state_manager: StateManager, token: str):
+    """Test that the state manager can modify a state exclusively.
+
+    Args:
+        state_manager: A state manager instance.
+        token: A token.
+    """
+    async with state_manager.modify_state(token):
+        if isinstance(state_manager, StateManagerRedis):
+            assert await state_manager.redis.get(f"{token}_lock")
+        elif isinstance(state_manager, StateManagerMemory):
+            assert token in state_manager._states_locks
+            assert state_manager._states_locks[token].locked()
+    # lock should be dropped after exiting the context
+    if isinstance(state_manager, StateManagerRedis):
+        assert (await state_manager.redis.get(f"{token}_lock")) is None
+    elif isinstance(state_manager, StateManagerMemory):
+        assert not state_manager._states_locks[token].locked()
+
+        # separate instances should NOT share locks
+        sm2 = StateManagerMemory(state=TestState)
+        assert sm2._state_manager_lock is state_manager._state_manager_lock
+        assert not sm2._states_locks
+        if state_manager._states_locks:
+            assert sm2._states_locks != state_manager._states_locks
+
+
+@pytest.mark.asyncio
+async def test_state_manager_contend(state_manager: StateManager, token: str):
+    """Multiple coroutines attempting to access the same state.
+
+    Args:
+        state_manager: A state manager instance.
+        token: A token.
+    """
+    n_coroutines = 10
+    exp_num1 = 10
+
+    async with state_manager.modify_state(token) as state:
+        state.num1 = 0
+
+    async def _coro():
+        async with state_manager.modify_state(token) as state:
+            await asyncio.sleep(0.01)
+            state.num1 += 1
+
+    tasks = [asyncio.create_task(_coro()) for _ in range(n_coroutines)]
+
+    for f in asyncio.as_completed(tasks):
+        await f
+
+    assert (await state_manager.get_state(token)).num1 == exp_num1
+
+    if isinstance(state_manager, StateManagerRedis):
+        assert (await state_manager.redis.get(f"{token}_lock")) is None
+    elif isinstance(state_manager, StateManagerMemory):
+        assert token in state_manager._states_locks
+        assert not state_manager._states_locks[token].locked()
+
+
+@pytest.fixture(scope="function")
+def state_manager_redis() -> Generator[StateManager, None, None]:
+    """Instance of state manager for redis only.
+
+    Yields:
+        A state manager instance
+    """
+    state_manager = StateManager.create(TestState)
+
+    if not isinstance(state_manager, StateManagerRedis):
+        pytest.skip("Test requires redis")
+
+    yield state_manager
+
+    asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
+
+
+@pytest.mark.asyncio
+async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str):
+    """Test that the state manager lock expires and raises exception exiting context.
+
+    Args:
+        state_manager_redis: A state manager instance.
+        token: A token.
+    """
+    state_manager_redis.lock_expiration = LOCK_EXPIRATION
+
+    async with state_manager_redis.modify_state(token):
+        await asyncio.sleep(0.01)
+
+    with pytest.raises(LockExpiredError):
+        async with state_manager_redis.modify_state(token):
+            await asyncio.sleep(LOCK_EXPIRE_SLEEP)
+
+
+@pytest.mark.asyncio
+async def test_state_manager_lock_expire_contend(
+    state_manager_redis: StateManager, token: str
+):
+    """Test that the state manager lock expires and queued waiters proceed.
+
+    Args:
+        state_manager_redis: A state manager instance.
+        token: A token.
+    """
+    exp_num1 = 4252
+    unexp_num1 = 666
+
+    state_manager_redis.lock_expiration = LOCK_EXPIRATION
+
+    order = []
+
+    async def _coro_blocker():
+        async with state_manager_redis.modify_state(token) as state:
+            order.append("blocker")
+            await asyncio.sleep(LOCK_EXPIRE_SLEEP)
+            state.num1 = unexp_num1
+
+    async def _coro_waiter():
+        while "blocker" not in order:
+            await asyncio.sleep(0.005)
+        async with state_manager_redis.modify_state(token) as state:
+            order.append("waiter")
+            assert state.num1 != unexp_num1
+            state.num1 = exp_num1
+
+    tasks = [
+        asyncio.create_task(_coro_blocker()),
+        asyncio.create_task(_coro_waiter()),
+    ]
+    with pytest.raises(LockExpiredError):
+        await tasks[0]
+    await tasks[1]
+
+    assert order == ["blocker", "waiter"]
+    assert (await state_manager_redis.get_state(token)).num1 == exp_num1
+
+
+@pytest.fixture(scope="function")
+def mock_app(monkeypatch, app: rx.App, state_manager: StateManager) -> rx.App:
+    """Mock app fixture.
+
+    Args:
+        monkeypatch: Pytest monkeypatch object.
+        app: An app.
+        state_manager: A state manager.
+
+    Returns:
+        The app, after mocking out prerequisites.get_app()
+    """
+    app_module = Mock()
+    setattr(app_module, APP_VAR, app)
+    app.state = TestState
+    app.state_manager = state_manager
+    assert app.event_namespace is not None
+    app.event_namespace.emit = AsyncMock()
+    monkeypatch.setattr(prerequisites, "get_app", lambda: app_module)
+    return app
+
+
+@pytest.mark.asyncio
+async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
+    """Test that the state proxy works.
+
+    Args:
+        grandchild_state: A grandchild state.
+        mock_app: An app that will be returned by `get_app()`
+    """
+    child_state = grandchild_state.parent_state
+    assert child_state is not None
+    parent_state = child_state.parent_state
+    assert parent_state is not None
+    if isinstance(mock_app.state_manager, StateManagerMemory):
+        mock_app.state_manager.states[parent_state.get_token()] = parent_state
+
+    sp = StateProxy(grandchild_state)
+    assert sp.__wrapped__ == grandchild_state
+    assert sp._self_substate_path == grandchild_state.get_full_name().split(".")
+    assert sp._self_app is mock_app
+    assert not sp._self_mutable
+    assert sp._self_actx is None
+
+    # cannot use normal contextmanager protocol
+    with pytest.raises(TypeError), sp:
+        pass
+
+    with pytest.raises(ImmutableStateError):
+        # cannot directly modify state proxy outside of async context
+        sp.value2 = 16
+
+    async with sp:
+        assert sp._self_actx is not None
+        assert sp._self_mutable  # proxy is mutable inside context
+        if isinstance(mock_app.state_manager, StateManagerMemory):
+            # For in-process store, only one instance of the state exists
+            assert sp.__wrapped__ is grandchild_state
+        else:
+            # When redis is used, a new+updated instance is assigned to the proxy
+            assert sp.__wrapped__ is not grandchild_state
+        sp.value2 = 42
+    assert not sp._self_mutable  # proxy is not mutable after exiting context
+    assert sp._self_actx is None
+    assert sp.value2 == 42
+
+    # Get the state from the state manager directly and check that the value is updated
+    gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token())
+    if isinstance(mock_app.state_manager, StateManagerMemory):
+        # For in-process store, only one instance of the state exists
+        assert gotten_state is parent_state
+    else:
+        assert gotten_state is not parent_state
+    gotten_grandchild_state = gotten_state.get_substate(sp._self_substate_path)
+    assert gotten_grandchild_state is not None
+    assert gotten_grandchild_state.value2 == 42
+
+    # ensure state update was emitted
+    assert mock_app.event_namespace is not None
+    mock_app.event_namespace.emit.assert_called_once()
+    mcall = mock_app.event_namespace.emit.mock_calls[0]
+    assert mcall.args[0] == str(SocketEvent.EVENT)
+    assert json.loads(mcall.args[1]) == StateUpdate(
+        delta={
+            parent_state.get_full_name(): {
+                "upper": "",
+                "sum": 3.14,
+            },
+            grandchild_state.get_full_name(): {
+                "value2": 42,
+            },
+        }
+    )
+    assert mcall.kwargs["to"] == grandchild_state.get_sid()
+
+
+class BackgroundTaskState(State):
+    """A state with a background task."""
+
+    order: List[str] = []
+    dict_list: Dict[str, List[int]] = {"foo": []}
+
+    @rx.background
+    async def background_task(self):
+        """A background task that updates the state."""
+        async with self:
+            assert not self.order
+            self.order.append("background_task:start")
+
+        assert isinstance(self, StateProxy)
+        with pytest.raises(ImmutableStateError):
+            self.order.append("bad idea")
+
+        with pytest.raises(ImmutableStateError):
+            # Even nested access to mutables raises an exception.
+            self.dict_list["foo"].append(42)
+
+        # wait for some other event to happen
+        while len(self.order) == 1:
+            await asyncio.sleep(0.01)
+            async with self:
+                pass  # update proxy instance
+
+        async with self:
+            self.order.append("background_task:stop")
+
+    @rx.background
+    async def background_task_generator(self):
+        """A background task generator that does nothing.
+
+        Yields:
+            None
+        """
+        yield
+
+    def other(self):
+        """Some other event that updates the state."""
+        self.order.append("other")
+
+    async def bad_chain1(self):
+        """Test that a background task cannot be chained."""
+        await self.background_task()
+
+    async def bad_chain2(self):
+        """Test that a background task generator cannot be chained."""
+        async for _foo in self.background_task_generator():
+            pass
+
+
+@pytest.mark.asyncio
+async def test_background_task_no_block(mock_app: rx.App, token: str):
+    """Test that a background task does not block other events.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+    router_data = {"query": {}}
+    mock_app.state_manager.state = mock_app.state = BackgroundTaskState
+    async for update in rx.app.process(  # type: ignore
+        mock_app,
+        Event(
+            token=token,
+            name=f"{BackgroundTaskState.get_name()}.background_task",
+            router_data=router_data,
+            payload={},
+        ),
+        sid="",
+        headers={},
+        client_ip="",
+    ):
+        # background task returns empty update immediately
+        assert update == StateUpdate()
+    assert len(mock_app.background_tasks) == 1
+
+    # wait for the coroutine to start
+    await asyncio.sleep(0.5 if CI else 0.1)
+    assert len(mock_app.background_tasks) == 1
+
+    # Process another normal event
+    async for update in rx.app.process(  # type: ignore
+        mock_app,
+        Event(
+            token=token,
+            name=f"{BackgroundTaskState.get_name()}.other",
+            router_data=router_data,
+            payload={},
+        ),
+        sid="",
+        headers={},
+        client_ip="",
+    ):
+        # other task returns delta
+        assert update == StateUpdate(
+            delta={
+                BackgroundTaskState.get_name(): {
+                    "order": [
+                        "background_task:start",
+                        "other",
+                    ],
+                }
+            }
+        )
+
+    # Explicit wait for background tasks
+    for task in tuple(mock_app.background_tasks):
+        await task
+    assert not mock_app.background_tasks
+
+    assert (await mock_app.state_manager.get_state(token)).order == [
+        "background_task:start",
+        "other",
+        "background_task:stop",
+    ]
+
+
+@pytest.mark.asyncio
+async def test_background_task_no_chain():
+    """Test that a background task cannot be chained."""
+    bts = BackgroundTaskState()
+    with pytest.raises(RuntimeError):
+        await bts.bad_chain1()
+    with pytest.raises(RuntimeError):
+        await bts.bad_chain2()
+
+
 def test_mutable_list(mutable_state):
 def test_mutable_list(mutable_state):
     """Test that mutable lists are tracked correctly.
     """Test that mutable lists are tracked correctly.