Переглянути джерело

[REF-1885] Shard Substates when serializing to Redis (#2574)

* Move sharding internal to StateManager

Avoid leaking sharding implementation details all over the State class and
breaking the API

* WiP StateManager based sharding

* Copy the state __dict__ when serializing to avoid breaking the instance

* State tests need to pass the correct substate token for redis

* state: when getting parent_state, set top_level=False

ensure that we don't end up with a broken tree

* test_app: get tests passing with redis by passing the correct token

refactor upload tests to suck less

* test_client_storage: look up substate key

* state.py: pass static checks

* test_dynamic_routes: working with redis state shard

* Update the remaining AppHarness tests to pass {token}_{state.get_full_name()}

* test_app: pass all tokens with state suffix

* StateManagerRedis: clean up commentary
Masen Furer 1 рік тому
батько
коміт
756bf9b0f4

+ 1 - 1
integration/test_client_storage.py

@@ -449,7 +449,7 @@ async def test_client_side_state(
     assert l1s.text == "l1s value"
 
     # reset the backend state to force refresh from client storage
-    async with client_side.modify_state(token) as state:
+    async with client_side.modify_state(f"{token}_state.client_side_state") as state:
         state.reset()
     driver.refresh()
 

+ 8 - 3
integration/test_dynamic_routes.py

@@ -85,6 +85,7 @@ def dynamic_route(
     """
     with app_harness_env.create(
         root=tmp_path_factory.mktemp(f"dynamic_route"),
+        app_name=f"dynamicroute_{app_harness_env.__name__.lower()}",
         app_source=DynamicRoute,  # type: ignore
     ) as harness:
         yield harness
@@ -146,7 +147,7 @@ def poll_for_order(
 
     async def _poll_for_order(exp_order: list[str]):
         async def _backend_state():
-            return await dynamic_route.get_state(token)
+            return await dynamic_route.get_state(f"{token}_state.dynamic_state")
 
         async def _check():
             return (await _backend_state()).substates[
@@ -194,7 +195,9 @@ async def test_on_load_navigate(
         assert link
         assert page_id_input
 
-        assert dynamic_route.poll_for_value(page_id_input) == str(ix)
+        assert dynamic_route.poll_for_value(
+            page_id_input, exp_not_equal=str(ix - 1)
+        ) == str(ix)
         assert dynamic_route.poll_for_value(raw_path_input) == f"/page/{ix}/"
     await poll_for_order(exp_order)
 
@@ -220,7 +223,9 @@ async def test_on_load_navigate(
     with poll_for_navigation(driver):
         driver.get(f"{driver.current_url}?foo=bar")
     await poll_for_order(exp_order)
-    assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar"
+    assert (
+        await dynamic_route.get_state(f"{token}_state.dynamic_state")
+    ).router.page.params["foo"] == "bar"
 
     # hit a 404 and ensure we still hydrate
     exp_order += ["/404-no page id"]

+ 1 - 1
integration/test_event_actions.py

@@ -207,7 +207,7 @@ def poll_for_order(
 
     async def _poll_for_order(exp_order: list[str]):
         async def _backend_state():
-            return await event_action.get_state(token)
+            return await event_action.get_state(f"{token}_state.event_action_state")
 
         async def _check():
             return (await _backend_state()).substates[

+ 1 - 1
integration/test_event_chain.py

@@ -298,7 +298,7 @@ def assert_token(event_chain: AppHarness, driver: WebDriver) -> str:
     token = event_chain.poll_for_value(token_input)
     assert token is not None
 
-    return token
+    return f"{token}_state.state"
 
 
 @pytest.mark.parametrize(

+ 5 - 1
integration/test_form_submit.py

@@ -221,7 +221,11 @@ async def test_submit(driver, form_submit: AppHarness):
     submit_input.click()
 
     async def get_form_data():
-        return (await form_submit.get_state(token)).substates["form_state"].form_data
+        return (
+            (await form_submit.get_state(f"{token}_state.form_state"))
+            .substates["form_state"]
+            .form_data
+        )
 
     # wait for the form data to arrive at the backend
     form_data = await AppHarness._poll_for_async(get_form_data)

+ 9 - 11
integration/test_input.py

@@ -76,6 +76,10 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     token = fully_controlled_input.poll_for_value(token_input)
     assert token
 
+    async def get_state_text():
+        state = await fully_controlled_input.get_state(f"{token}_state.state")
+        return state.substates["state"].text
+
     # find the input and wait for it to have the initial state value
     debounce_input = driver.find_element(By.ID, "debounce_input_input")
     value_input = driver.find_element(By.ID, "value_input")
@@ -95,16 +99,14 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("foo")
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "ifoonitial"
-    assert (await fully_controlled_input.get_state(token)).substates[
-        "state"
-    ].text == "ifoonitial"
+    assert await get_state_text() == "ifoonitial"
     assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
     assert fully_controlled_input.poll_for_value(plain_value_input) == "ifoonitial"
 
     # clear the input on the backend
-    async with fully_controlled_input.modify_state(token) as state:
+    async with fully_controlled_input.modify_state(f"{token}_state.state") as state:
         state.substates["state"].text = ""
-    assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
+    assert await get_state_text() == ""
     assert (
         fully_controlled_input.poll_for_value(
             debounce_input, exp_not_equal="ifoonitial"
@@ -116,9 +118,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("getting testing done")
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "getting testing done"
-    assert (await fully_controlled_input.get_state(token)).substates[
-        "state"
-    ].text == "getting testing done"
+    assert await get_state_text() == "getting testing done"
     assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
     assert (
         fully_controlled_input.poll_for_value(plain_value_input)
@@ -130,9 +130,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "overwrite the state"
     assert on_change_input.get_attribute("value") == "overwrite the state"
-    assert (await fully_controlled_input.get_state(token)).substates[
-        "state"
-    ].text == "overwrite the state"
+    assert await get_state_text() == "overwrite the state"
     assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
     assert (
         fully_controlled_input.poll_for_value(plain_value_input)

+ 15 - 4
integration/test_upload.py

@@ -171,6 +171,7 @@ async def test_upload_file(
     # wait for the backend connection to send the token
     token = upload_file.poll_for_value(token_input)
     assert token is not None
+    substate_token = f"{token}_state.upload_state"
 
     suffix = "_secondary" if secondary else ""
 
@@ -191,7 +192,11 @@ async def test_upload_file(
 
     # look up the backend state and assert on uploaded contents
     async def get_file_data():
-        return (await upload_file.get_state(token)).substates["upload_state"]._file_data
+        return (
+            (await upload_file.get_state(substate_token))
+            .substates["upload_state"]
+            ._file_data
+        )
 
     file_data = await AppHarness._poll_for_async(get_file_data)
     assert isinstance(file_data, dict)
@@ -201,7 +206,7 @@ async def test_upload_file(
     selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
     assert selected_files.text == exp_name
 
-    state = await upload_file.get_state(token)
+    state = await upload_file.get_state(substate_token)
     if secondary:
         # only the secondary form tracks progress and chain events
         assert state.substates["upload_state"].event_order.count("upload_progress") == 1
@@ -223,6 +228,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
     # wait for the backend connection to send the token
     token = upload_file.poll_for_value(token_input)
     assert token is not None
+    substate_token = f"{token}_state.upload_state"
 
     upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
     assert upload_box
@@ -250,7 +256,11 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
 
     # look up the backend state and assert on uploaded contents
     async def get_file_data():
-        return (await upload_file.get_state(token)).substates["upload_state"]._file_data
+        return (
+            (await upload_file.get_state(substate_token))
+            .substates["upload_state"]
+            ._file_data
+        )
 
     file_data = await AppHarness._poll_for_async(get_file_data)
     assert isinstance(file_data, dict)
@@ -330,6 +340,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
     # wait for the backend connection to send the token
     token = upload_file.poll_for_value(token_input)
     assert token is not None
+    substate_token = f"{token}_state.upload_state"
 
     upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1]
     upload_button = driver.find_element(By.ID, f"upload_button_secondary")
@@ -347,7 +358,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
     cancel_button.click()
 
     # look up the backend state and assert on progress
-    state = await upload_file.get_state(token)
+    state = await upload_file.get_state(substate_token)
     assert state.substates["upload_state"].progress_dicts
     assert exp_name not in state.substates["upload_state"]._file_data
 

+ 4 - 3
reflex/app.py

@@ -926,7 +926,7 @@ async def process(
         }
     )
     # Get the state for the session exclusively.
-    async with app.state_manager.modify_state(event.token) as state:
+    async with app.state_manager.modify_state(event.substate_token) as state:
         # re-assign only when the value is different
         if state.router_data != router_data:
             # assignment will recurse into substates and force recalculation of
@@ -1002,7 +1002,8 @@ def upload(app: App):
             )
 
         # Get the state for the session.
-        state = await app.state_manager.get_state(token)
+        substate_token = token + "_" + handler.rpartition(".")[0]
+        state = await app.state_manager.get_state(substate_token)
 
         # get the current session ID
         # get the current state(parent state/substate)
@@ -1049,7 +1050,7 @@ def upload(app: App):
                 Each state update as JSON followed by a new line.
             """
             # Process the event.
-            async with app.state_manager.modify_state(token) as state:
+            async with app.state_manager.modify_state(event.substate_token) as state:
                 async for update in state._process(event):
                     # Postprocess the event.
                     update = await app.postprocess(state, event, update)

+ 10 - 0
reflex/event.py

@@ -41,6 +41,16 @@ class Event(Base):
     # The event payload.
     payload: Dict[str, Any] = {}
 
+    @property
+    def substate_token(self) -> str:
+        """Get the substate token for the event.
+
+        Returns:
+            The substate token.
+        """
+        substate = self.name.rpartition(".")[0]
+        return f"{self.token}_{substate}"
+
 
 BACKGROUND_TASK_MARKER = "_reflex_background_task"
 

+ 157 - 13
reflex/state.py

@@ -213,21 +213,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # The router data for the current page
     router: RouterData = RouterData()
 
-    def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
+    def __init__(
+        self,
+        *args,
+        parent_state: BaseState | None = None,
+        init_substates: bool = True,
+        **kwargs,
+    ):
         """Initialize the state.
 
         Args:
             *args: The args to pass to the Pydantic init method.
             parent_state: The parent state.
+            init_substates: Whether to initialize the substates in this instance.
             **kwargs: The kwargs to pass to the Pydantic init method.
 
         """
         kwargs["parent_state"] = parent_state
         super().__init__(*args, **kwargs)
 
-        # Setup the substates.
-        for substate in self.get_substates():
-            self.substates[substate.get_name()] = substate(parent_state=self)
+        # Setup the substates (for memory state manager only).
+        if init_substates:
+            for substate in self.get_substates():
+                self.substates[substate.get_name()] = substate(parent_state=self)
         # Convert the event handlers to functions.
         self._init_event_handlers()
 
@@ -1005,7 +1013,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         for substate in self.substates.values():
             substate._reset_client_storage()
 
-    def get_substate(self, path: Sequence[str]) -> BaseState | None:
+    def get_substate(self, path: Sequence[str]) -> BaseState:
         """Get the substate.
 
         Args:
@@ -1260,6 +1268,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Recursively find the substate deltas.
         substates = self.substates
         for substate in self.dirty_substates.union(self._always_dirty_substates):
+            if substate not in substates:
+                continue  # substate not loaded at this time, no delta
             delta.update(substates[substate].get_delta())
 
         # Format the delta.
@@ -1287,6 +1297,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         for var in self.dirty_vars:
             for substate_name in self._substate_var_dependencies[var]:
                 self.dirty_substates.add(substate_name)
+                if substate_name not in substates:
+                    continue
                 substate = substates[substate_name]
                 substate.dirty_vars.add(var)
                 substate._mark_dirty()
@@ -1295,6 +1307,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """Reset the dirty vars."""
         # Recursively clean the substates.
         for substate in self.dirty_substates:
+            if substate not in self.substates:
+                continue
             self.substates[substate]._clean()
 
         # Clean this state.
@@ -1380,6 +1394,24 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         pass
 
+    def __getstate__(self):
+        """Get the state for redis serialization.
+
+        This method is called by cloudpickle to serialize the object.
+
+        It explicitly removes parent_state and substates because those are serialized separately
+        by the StateManagerRedis to allow for better horizontal scaling as state size increases.
+
+        Returns:
+            The state dict for serialization.
+        """
+        state = super().__getstate__()
+        # Never serialize parent_state or substates
+        state["__dict__"] = state["__dict__"].copy()
+        state["__dict__"]["parent_state"] = None
+        state["__dict__"]["substates"] = {}
+        return state
+
 
 class State(BaseState):
     """The app Base State."""
@@ -1479,6 +1511,8 @@ class StateProxy(wrapt.ObjectProxy):
         """
         self._self_actx = self._self_app.modify_state(
             self.__wrapped__.router.session.client_token
+            + "_"
+            + ".".join(self._self_substate_path)
         )
         mutable_state = await self._self_actx.__aenter__()
         super().__setattr__(
@@ -1675,6 +1709,8 @@ class StateManagerMemory(StateManager):
         Returns:
             The state for the token.
         """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        token = token.partition("_")[0]
         if token not in self.states:
             self.states[token] = self.state()
         return self.states[token]
@@ -1698,6 +1734,8 @@ class StateManagerMemory(StateManager):
         Yields:
             The state for the token.
         """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        token = token.partition("_")[0]
         if token not in self._states_locks:
             async with self._state_manager_lock:
                 if token not in self._states_locks:
@@ -1737,23 +1775,104 @@ class StateManagerRedis(StateManager):
         b"evicted",
     }
 
-    async def get_state(self, token: str) -> BaseState:
+    async def get_state(
+        self,
+        token: str,
+        top_level: bool = True,
+        get_substates: bool = True,
+        parent_state: BaseState | None = None,
+    ) -> BaseState:
         """Get the state for a token.
 
         Args:
             token: The token to get the state for.
+            top_level: If true, return an instance of the top-level state.
+            get_substates: If true, also retrieve substates
+            parent_state: If provided, use this parent_state instead of getting it from redis.
 
         Returns:
             The state for the token.
+
+        Raises:
+            RuntimeError: when the state_cls is not specified in the token
         """
+        # Split the actual token from the fully qualified substate name.
+        client_token, _, state_path = token.partition("_")
+        if state_path:
+            # Get the State class associated with the given path.
+            state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
+        else:
+            raise RuntimeError(
+                "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
+            )
+
+        # Fetch the serialized substate from redis.
         redis_state = await self.redis.get(token)
-        if redis_state is None:
-            await self.set_state(token, self.state())
-            return await self.get_state(token)
-        return cloudpickle.loads(redis_state)
+
+        if redis_state is not None:
+            # Deserialize the substate.
+            state = cloudpickle.loads(redis_state)
+
+            # Populate parent and substates if requested.
+            if parent_state is None:
+                # Retrieve the parent state from redis.
+                parent_state_name = state_path.rpartition(".")[0]
+                if parent_state_name:
+                    parent_state_key = token.rpartition(".")[0]
+                    parent_state = await self.get_state(
+                        parent_state_key, top_level=False, get_substates=False
+                    )
+            # Set up Bidirectional linkage between this state and its parent.
+            if parent_state is not None:
+                parent_state.substates[state.get_name()] = state
+                state.parent_state = parent_state
+            if get_substates:
+                # Retrieve all substates from redis.
+                for substate_cls in state_cls.get_substates():
+                    substate_name = substate_cls.get_name()
+                    substate_key = token + "." + substate_name
+                    state.substates[substate_name] = await self.get_state(
+                        substate_key, top_level=False, parent_state=state
+                    )
+            # To retain compatibility with previous implementation, by default, we return
+            # the top-level state by chasing `parent_state` pointers up the tree.
+            if top_level:
+                while type(state) != self.state and state.parent_state is not None:
+                    state = state.parent_state
+            return state
+
+        # Key didn't exist so we have to create a new entry for this token.
+        if parent_state is None:
+            parent_state_name = state_path.rpartition(".")[0]
+            if parent_state_name:
+                # Retrieve the parent state to populate event handlers onto this substate.
+                parent_state_key = client_token + "_" + parent_state_name
+                parent_state = await self.get_state(
+                    parent_state_key, top_level=False, get_substates=False
+                )
+        # Persist the new state class to redis.
+        await self.set_state(
+            token,
+            state_cls(
+                parent_state=parent_state,
+                init_substates=False,
+            ),
+        )
+        # After creating the state key, recursively call `get_state` to populate substates.
+        return await self.get_state(
+            token,
+            top_level=top_level,
+            get_substates=get_substates,
+            parent_state=parent_state,
+        )
 
     async def set_state(
-        self, token: str, state: BaseState, lock_id: bytes | None = None
+        self,
+        token: str,
+        state: BaseState,
+        lock_id: bytes | None = None,
+        set_substates: bool = True,
+        set_parent_state: bool = True,
     ):
         """Set the state for a token.
 
@@ -1761,11 +1880,13 @@ class StateManagerRedis(StateManager):
             token: The token to set the state for.
             state: The state to set.
             lock_id: If provided, the lock_key must be set to this value to set the state.
+            set_substates: If True, write substates to redis
+            set_parent_state: If True, write parent state to redis
 
         Raises:
             LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
         """
-        # check that we're holding the lock
+        # Check that we're holding the lock.
         if (
             lock_id is not None
             and await self.redis.get(self._lock_key(token)) != lock_id
@@ -1775,6 +1896,27 @@ class StateManagerRedis(StateManager):
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                 "or use `@rx.background` decorator for long-running tasks."
             )
+        # Find the substate associated with the token.
+        state_path = token.partition("_")[2]
+        if state_path and state.get_full_name() != state_path:
+            state = state.get_substate(tuple(state_path.split(".")))
+        # Persist the parent state separately, if requested.
+        if state.parent_state is not None and set_parent_state:
+            parent_state_key = token.rpartition(".")[0]
+            await self.set_state(
+                parent_state_key,
+                state.parent_state,
+                lock_id=lock_id,
+                set_substates=False,
+            )
+        # Persist the substates separately, if requested.
+        if set_substates:
+            for substate_name, substate in state.substates.items():
+                substate_key = token + "." + substate_name
+                await self.set_state(
+                    substate_key, substate, lock_id=lock_id, set_parent_state=False
+                )
+        # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
         await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
 
     @contextlib.asynccontextmanager
@@ -1802,7 +1944,9 @@ class StateManagerRedis(StateManager):
         Returns:
             The redis lock key for the token.
         """
-        return f"{token}_lock".encode()
+        # All substates share the same lock domain, so ignore any substate path suffix.
+        client_token = token.partition("_")[0]
+        return f"{client_token}_lock".encode()
 
     async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
         """Try to get a redis lock for a token.

+ 1 - 0
reflex/testing.py

@@ -220,6 +220,7 @@ class AppHarness:
             reflex.config.get_config(reload=True)
             # reset rx.State subclasses
             State.class_subclasses.clear()
+            State.get_class_substate.cache_clear()
             # Ensure the AppHarness test does not skip State assignment due to running via pytest
             os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
             # self.app_module.app.

+ 19 - 22
tests/test_app.py

@@ -340,7 +340,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str):
     assert app.state == test_state
 
     # Get a state for a given token.
-    state = await app.state_manager.get_state(token)
+    state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}")
     assert isinstance(state, test_state)
     assert state.var == 0  # type: ignore
 
@@ -358,8 +358,8 @@ async def test_set_and_get_state(test_state):
     app = App(state=test_state)
 
     # Create two tokens.
-    token1 = str(uuid.uuid4())
-    token2 = str(uuid.uuid4())
+    token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"
+    token2 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"
 
     # Get the default state for each token.
     state1 = await app.state_manager.get_state(token1)
@@ -744,18 +744,18 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
     # The App state must be the "root" of the state tree
     app = App(state=State)
     app.event_namespace.emit = AsyncMock()  # type: ignore
-    current_state = await app.state_manager.get_state(token)
+    substate_token = f"{token}_{state.get_full_name()}"
+    current_state = await app.state_manager.get_state(substate_token)
     data = b"This is binary data"
 
     # Create a binary IO object and write data to it
     bio = io.BytesIO()
     bio.write(data)
 
-    state_name = state.get_full_name().partition(".")[2] or state.get_name()
     request_mock = unittest.mock.Mock()
     request_mock.headers = {
         "reflex-client-token": token,
-        "reflex-event-handler": f"state.{state_name}.multi_handle_upload",
+        "reflex-event-handler": f"{state.get_full_name()}.multi_handle_upload",
     }
 
     file1 = UploadFile(
@@ -774,7 +774,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
             == StateUpdate(delta=delta, events=[], final=True).json() + "\n"
         )
 
-    current_state = await app.state_manager.get_state(token)
+    current_state = await app.state_manager.get_state(substate_token)
     state_dict = current_state.dict()[state.get_full_name()]
     assert state_dict["img_list"] == [
         "image1.jpg",
@@ -799,14 +799,12 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
         token: a Token.
     """
     state._tmp_path = tmp_path
-    # The App state must be the "root" of the state tree
-    app = App(state=state if state is FileUploadState else FileStateBase1)
+    app = App(state=State)
 
-    state_name = state.get_full_name().partition(".")[2] or state.get_name()
     request_mock = unittest.mock.Mock()
     request_mock.headers = {
         "reflex-client-token": token,
-        "reflex-event-handler": f"{state_name}.handle_upload2",
+        "reflex-event-handler": f"{state.get_full_name()}.handle_upload2",
     }
     file_mock = unittest.mock.Mock(filename="image1.jpg")
     fn = upload(app)
@@ -814,7 +812,7 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
         await fn(request_mock, [file_mock])
     assert (
         err.value.args[0]
-        == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
+        == f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
     )
 
     if isinstance(app.state_manager, StateManagerRedis):
@@ -835,14 +833,12 @@ async def test_upload_file_background(state, tmp_path, token):
         token: a Token.
     """
     state._tmp_path = tmp_path
-    # The App state must be the "root" of the state tree
-    app = App(state=state if state is FileUploadState else FileStateBase1)
+    app = App(state=State)
 
-    state_name = state.get_full_name().partition(".")[2] or state.get_name()
     request_mock = unittest.mock.Mock()
     request_mock.headers = {
         "reflex-client-token": token,
-        "reflex-event-handler": f"{state_name}.bg_upload",
+        "reflex-event-handler": f"{state.get_full_name()}.bg_upload",
     }
     file_mock = unittest.mock.Mock(filename="image1.jpg")
     fn = upload(app)
@@ -850,7 +846,7 @@ async def test_upload_file_background(state, tmp_path, token):
         await fn(request_mock, [file_mock])
     assert (
         err.value.args[0]
-        == f"@rx.background is not supported for upload handler `{state_name}.bg_upload`."
+        == f"@rx.background is not supported for upload handler `{state.get_full_name()}.bg_upload`."
     )
 
     if isinstance(app.state_manager, StateManagerRedis):
@@ -932,9 +928,10 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     }
     assert constants.ROUTER in app.state()._computed_var_dependencies
 
+    substate_token = f"{token}_{DynamicState.get_full_name()}"
     sid = "mock_sid"
     client_ip = "127.0.0.1"
-    state = await app.state_manager.get_state(token)
+    state = await app.state_manager.get_state(substate_token)
     assert state.dynamic == ""
     exp_vals = ["foo", "foobar", "baz"]
 
@@ -1004,7 +1001,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         if isinstance(app.state_manager, StateManagerRedis):
             # When redis is used, the state is not updated until the processing is complete
-            state = await app.state_manager.get_state(token)
+            state = await app.state_manager.get_state(substate_token)
             assert state.dynamic == prev_exp_val
 
         # complete the processing
@@ -1012,7 +1009,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             await process_coro.__anext__()
 
         # check that router data was written to the state_manager store
-        state = await app.state_manager.get_state(token)
+        state = await app.state_manager.get_state(substate_token)
         assert state.dynamic == exp_val
 
         process_coro = process(
@@ -1087,7 +1084,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             await process_coro.__anext__()
 
         prev_exp_val = exp_val
-    state = await app.state_manager.get_state(token)
+    state = await app.state_manager.get_state(substate_token)
     assert state.loaded == len(exp_vals)
     assert state.counter == len(exp_vals)
     # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
@@ -1124,7 +1121,7 @@ async def test_process_events(mocker, token: str):
     async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
         pass
 
-    assert (await app.state_manager.get_state(token)).value == 5
+    assert (await app.state_manager.get_state(event.substate_token)).value == 5
     assert app.postprocess.call_count == 6
 
     if isinstance(app.state_manager, StateManagerRedis):

+ 59 - 18
tests/test_state.py

@@ -38,8 +38,8 @@ from reflex.vars import BaseVar, ComputedVar
 from .states import GenState
 
 CI = bool(os.environ.get("CI", False))
-LOCK_EXPIRATION = 2000 if CI else 100
-LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
+LOCK_EXPIRATION = 2000 if CI else 300
+LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
 
 
 formatted_router = {
@@ -1432,15 +1432,32 @@ def state_manager(request) -> Generator[StateManager, None, None]:
         asyncio.get_event_loop().run_until_complete(state_manager.close())
 
 
+@pytest.fixture()
+def substate_token(state_manager, token):
+    """A token + substate name for looking up in state manager.
+
+    Args:
+        state_manager: A state manager instance.
+        token: A token.
+
+    Returns:
+        Token concatenated with the state_manager's state full_name.
+    """
+    return f"{token}_{state_manager.state.get_full_name()}"
+
+
 @pytest.mark.asyncio
-async def test_state_manager_modify_state(state_manager: StateManager, token: str):
+async def test_state_manager_modify_state(
+    state_manager: StateManager, token: str, substate_token: str
+):
     """Test that the state manager can modify a state exclusively.
 
     Args:
         state_manager: A state manager instance.
         token: A token.
+        substate_token: A token + substate name for looking up in state manager.
     """
-    async with state_manager.modify_state(token):
+    async with state_manager.modify_state(substate_token):
         if isinstance(state_manager, StateManagerRedis):
             assert await state_manager.redis.get(f"{token}_lock")
         elif isinstance(state_manager, StateManagerMemory):
@@ -1461,21 +1478,24 @@ async def test_state_manager_modify_state(state_manager: StateManager, token: st
 
 
 @pytest.mark.asyncio
-async def test_state_manager_contend(state_manager: StateManager, token: str):
+async def test_state_manager_contend(
+    state_manager: StateManager, token: str, substate_token: str
+):
     """Multiple coroutines attempting to access the same state.
 
     Args:
         state_manager: A state manager instance.
         token: A token.
+        substate_token: A token + substate name for looking up in state manager.
     """
     n_coroutines = 10
     exp_num1 = 10
 
-    async with state_manager.modify_state(token) as state:
+    async with state_manager.modify_state(substate_token) as state:
         state.num1 = 0
 
     async def _coro():
-        async with state_manager.modify_state(token) as state:
+        async with state_manager.modify_state(substate_token) as state:
             await asyncio.sleep(0.01)
             state.num1 += 1
 
@@ -1484,7 +1504,7 @@ async def test_state_manager_contend(state_manager: StateManager, token: str):
     for f in asyncio.as_completed(tasks):
         await f
 
-    assert (await state_manager.get_state(token)).num1 == exp_num1
+    assert (await state_manager.get_state(substate_token)).num1 == exp_num1
 
     if isinstance(state_manager, StateManagerRedis):
         assert (await state_manager.redis.get(f"{token}_lock")) is None
@@ -1510,33 +1530,51 @@ def state_manager_redis() -> Generator[StateManager, None, None]:
     asyncio.get_event_loop().run_until_complete(state_manager.close())
 
 
+@pytest.fixture()
+def substate_token_redis(state_manager_redis, token):
+    """A token + substate name for looking up in state manager.
+
+    Args:
+        state_manager_redis: A state manager instance.
+        token: A token.
+
+    Returns:
+        Token concatenated with the state_manager's state full_name.
+    """
+    return f"{token}_{state_manager_redis.state.get_full_name()}"
+
+
 @pytest.mark.asyncio
-async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str):
+async def test_state_manager_lock_expire(
+    state_manager_redis: StateManager, token: str, substate_token_redis: str
+):
     """Test that the state manager lock expires and raises exception exiting context.
 
     Args:
         state_manager_redis: A state manager instance.
         token: A token.
+        substate_token_redis: A token + substate name for looking up in state manager.
     """
     state_manager_redis.lock_expiration = LOCK_EXPIRATION
 
-    async with state_manager_redis.modify_state(token):
+    async with state_manager_redis.modify_state(substate_token_redis):
         await asyncio.sleep(0.01)
 
     with pytest.raises(LockExpiredError):
-        async with state_manager_redis.modify_state(token):
+        async with state_manager_redis.modify_state(substate_token_redis):
             await asyncio.sleep(LOCK_EXPIRE_SLEEP)
 
 
 @pytest.mark.asyncio
 async def test_state_manager_lock_expire_contend(
-    state_manager_redis: StateManager, token: str
+    state_manager_redis: StateManager, token: str, substate_token_redis: str
 ):
     """Test that the state manager lock expires and queued waiters proceed.
 
     Args:
         state_manager_redis: A state manager instance.
         token: A token.
+        substate_token_redis: A token + substate name for looking up in state manager.
     """
     exp_num1 = 4252
     unexp_num1 = 666
@@ -1546,7 +1584,7 @@ async def test_state_manager_lock_expire_contend(
     order = []
 
     async def _coro_blocker():
-        async with state_manager_redis.modify_state(token) as state:
+        async with state_manager_redis.modify_state(substate_token_redis) as state:
             order.append("blocker")
             await asyncio.sleep(LOCK_EXPIRE_SLEEP)
             state.num1 = unexp_num1
@@ -1554,7 +1592,7 @@ async def test_state_manager_lock_expire_contend(
     async def _coro_waiter():
         while "blocker" not in order:
             await asyncio.sleep(0.005)
-        async with state_manager_redis.modify_state(token) as state:
+        async with state_manager_redis.modify_state(substate_token_redis) as state:
             order.append("waiter")
             assert state.num1 != unexp_num1
             state.num1 = exp_num1
@@ -1568,7 +1606,7 @@ async def test_state_manager_lock_expire_contend(
     await tasks[1]
 
     assert order == ["blocker", "waiter"]
-    assert (await state_manager_redis.get_state(token)).num1 == exp_num1
+    assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
 
 
 @pytest.fixture(scope="function")
@@ -1643,7 +1681,8 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     assert sp.value2 == 42
 
     # Get the state from the state manager directly and check that the value is updated
-    gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token())
+    gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}"
+    gotten_state = await mock_app.state_manager.get_state(gc_token)
     if isinstance(mock_app.state_manager, StateManagerMemory):
         # For in-process store, only one instance of the state exists
         assert gotten_state is parent_state
@@ -1836,7 +1875,8 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         "private",
     ]
 
-    assert (await mock_app.state_manager.get_state(token)).order == exp_order
+    substate_token = f"{token}_{BackgroundTaskState.get_name()}"
+    assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order
 
     assert mock_app.event_namespace is not None
     emit_mock = mock_app.event_namespace.emit
@@ -1913,7 +1953,8 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
         await task
     assert not mock_app.background_tasks
 
-    assert (await mock_app.state_manager.get_state(token)).order == [
+    substate_token = f"{token}_{BackgroundTaskState.get_name()}"
+    assert (await mock_app.state_manager.get_state(substate_token)).order == [
         "reset",
     ]