1
0
Эх сурвалжийг харах

[REF-1988] API to Get instance of Arbitrary State class (#2678)

* WiP get_state

* Refactor get_state fast path

Rudimentary protection for state instance access from a background task
(StateProxy)

* retain dirty substate marking per `_mark_dirty` call to avoid test changes

* Find common ancestor by part instead of by character

Fix StateProxy for substates and parent_state attributes (have to handle in
__getattr__, not property)

Fix type annotation for `get_state`

* test_state: workflow test for `get_state` functionality

* Do not reset _always_dirty_substates when adding vars

Reset the substate tracking only when the class is instantiated.

* test_state_tree: test substate access in a larger state tree

Ensure that `get_state` returns the proper "branch" of the state tree depending
on what substate is requested.

* test_format: fixup broken tests from adding substates of TestState

* Fix flaky integration tests with more polling

* AppHarness: reset _always_dirty_substates on rx.State

* RuntimeError unless State is instantiated with _reflex_internal_init=True

Avoid user errors trying to directly instantiate State classes

* Helper functions for _substate_key and _split_substate_key

Unify the implementation of generating and decoding the token + state name
format used for redis state sharding.

* StateManagerRedis: use create_task in get_state and set_state

read and write substates concurrently (allow redis to shine)

* test_state_inheritance: use polling cuz life too short for flaky tests

kthnxbai :heart:

* Move _is_testing_env to reflex.utils.exec.is_testing_env

Reuse the code in app.py

* Break up `BaseState.get_state` and friends into separate methods

* Add test case for pre-fetching cached var dependency

* Move on_load_internal and update_vars_internal to substates

Avoid loading the entire state tree to process these common internal events. If
the state tree is very large, this allow page navigation to occur more quickly.

Pre-fetch substates that contain cached vars, as they may need to be recomputed
if certain vars change.

* Do not copy ROUTER_DATA into all substates.

This is a waste of time and memory, and can be handled via a special case in
__getattribute__

* Track whether State instance _was_touched

Avoid wasting time serializing states that have no modifications

* Do not persist states in `StateManagerRedis.get_state`

Wait until the state is actually modified, and then persist it as part of `set_state`.

Factor out common logic into helper methods for readability and to reduce
duplication of common logic.

To avoid having to recursively call `get_state`, which would require persisting
the instance and then getting it again, some of the initialization logic
regarding parent_state and substates is duplicated when creating a new
instance. This is for performance reasons.

* Remove stray print()

* context.js.jinja2: fix check for empty local storage / cookie vars

* Add comments for onLoadInternalEvent and initialEvents

* nit: typo

* split _get_was_touched into _update_was_touched

Improve clarity in cases where _get_was_touched was being called for its side
effects only.

* Remove extraneous information from incorrect State instantiation error

* Update missing redis exception message
Masen Furer 1 жил өмнө
parent
commit
deae662e2a

+ 2 - 2
integration/test_client_storage.py

@@ -518,8 +518,8 @@ async def test_client_side_state(
     set_sub("l6", "l6 value")
     set_sub("l6", "l6 value")
     l5 = driver.find_element(By.ID, "l5")
     l5 = driver.find_element(By.ID, "l5")
     l6 = driver.find_element(By.ID, "l6")
     l6 = driver.find_element(By.ID, "l6")
+    assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert l5.text == "l5 value"
     assert l5.text == "l5 value"
-    assert l6.text == "l6 value"
 
 
     # Switch back to main window.
     # Switch back to main window.
     driver.switch_to.window(main_tab)
     driver.switch_to.window(main_tab)
@@ -527,8 +527,8 @@ async def test_client_side_state(
     # The values should have updated automatically.
     # The values should have updated automatically.
     l5 = driver.find_element(By.ID, "l5")
     l5 = driver.find_element(By.ID, "l5")
     l6 = driver.find_element(By.ID, "l6")
     l6 = driver.find_element(By.ID, "l6")
+    assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert l5.text == "l5 value"
     assert l5.text == "l5 value"
-    assert l6.text == "l6 value"
 
 
     # clear the cookie jar and local storage, ensure state reset to default
     # clear the cookie jar and local storage, ensure state reset to default
     driver.delete_all_cookies()
     driver.delete_all_cookies()

+ 19 - 4
integration/test_state_inheritance.py

@@ -1,14 +1,29 @@
 """Test state inheritance."""
 """Test state inheritance."""
 
 
-import time
+from contextlib import suppress
 from typing import Generator
 from typing import Generator
 
 
 import pytest
 import pytest
+from selenium.common.exceptions import NoAlertPresentException
+from selenium.webdriver.common.alert import Alert
 from selenium.webdriver.common.by import By
 from selenium.webdriver.common.by import By
 
 
 from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
 from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
 
 
 
 
+def get_alert_or_none(driver: WebDriver) -> Alert | None:
+    """Switch to an alert if present.
+
+    Args:
+        driver: WebDriver instance.
+
+    Returns:
+        The alert if present, otherwise None.
+    """
+    with suppress(NoAlertPresentException):
+        return driver.switch_to.alert
+
+
 def raises_alert(driver: WebDriver, element: str) -> None:
 def raises_alert(driver: WebDriver, element: str) -> None:
     """Click an element and check that an alert is raised.
     """Click an element and check that an alert is raised.
 
 
@@ -18,8 +33,8 @@ def raises_alert(driver: WebDriver, element: str) -> None:
     """
     """
     btn = driver.find_element(By.ID, element)
     btn = driver.find_element(By.ID, element)
     btn.click()
     btn.click()
-    time.sleep(0.2)  # wait for the alert to appear
-    alert = driver.switch_to.alert
+    alert = AppHarness._poll_for(lambda: get_alert_or_none(driver))
+    assert isinstance(alert, Alert)
     assert alert.text == "clicked"
     assert alert.text == "clicked"
     alert.accept()
     alert.accept()
 
 
@@ -355,7 +370,7 @@ def test_state_inheritance(
     child3_other_mixin_btn = driver.find_element(By.ID, "child3-other-mixin-btn")
     child3_other_mixin_btn = driver.find_element(By.ID, "child3-other-mixin-btn")
     child3_other_mixin_btn.click()
     child3_other_mixin_btn.click()
     child2_other_mixin_value = state_inheritance.poll_for_content(
     child2_other_mixin_value = state_inheritance.poll_for_content(
-        child2_other_mixin, exp_not_equal="other_mixin"
+        child2_other_mixin, exp_not_equal="Child2.clicked.1"
     )
     )
     child2_computed_mixin_value = state_inheritance.poll_for_content(
     child2_computed_mixin_value = state_inheritance.poll_for_content(
         child2_computed_other_mixin, exp_not_equal="other_mixin"
         child2_computed_other_mixin, exp_not_equal="other_mixin"

+ 24 - 4
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -25,11 +25,31 @@ export const clientStorage = {}
 
 
 {% if state_name %}
 {% if state_name %}
 export const state_name = "{{state_name}}"
 export const state_name = "{{state_name}}"
-export const onLoadInternalEvent = () => [
-    Event('{{state_name}}.{{const.update_vars_internal}}', {vars: hydrateClientStorage(clientStorage)}),
-    Event('{{state_name}}.{{const.on_load_internal}}')
-]
 
 
+// Theses events are triggered on initial load and each page navigation.
+export const onLoadInternalEvent = () => {
+    const internal_events = [];
+
+    // Get tracked cookie and local storage vars to send to the backend.
+    const client_storage_vars = hydrateClientStorage(clientStorage);
+    // But only send the vars if any are actually set in the browser.
+    if (client_storage_vars && Object.keys(client_storage_vars).length !== 0) {
+        internal_events.push(
+            Event(
+                '{{state_name}}.{{const.update_vars_internal}}',
+                {vars: client_storage_vars},
+            ),
+        );
+    }
+
+    // `on_load_internal` triggers the correct on_load event(s) for the current page.
+    // If the page does not define any on_load event, this will just set `is_hydrated = true`.
+    internal_events.push(Event('{{state_name}}.{{const.on_load_internal}}'));
+
+    return internal_events;
+}
+
+// The following events are sent when the websocket connects or reconnects.
 export const initialEvents = () => [
 export const initialEvents = () => [
     Event('{{state_name}}.{{const.hydrate}}'),
     Event('{{state_name}}.{{const.hydrate}}'),
     ...onLoadInternalEvent()
     ...onLoadInternalEvent()

+ 1 - 1
reflex/.templates/web/utils/state.js

@@ -587,7 +587,7 @@ export const useEventLoop = (
       if (storage_to_state_map[e.key]) {
       if (storage_to_state_map[e.key]) {
         const vars = {}
         const vars = {}
         vars[storage_to_state_map[e.key]] = e.newValue
         vars[storage_to_state_map[e.key]] = e.newValue
-        const event = Event(`${state_name}.update_vars_internal`, {vars: vars})
+        const event = Event(`${state_name}.update_vars_internal_state.update_vars_internal`, {vars: vars})
         addEvents([event], e);
         addEvents([event], e);
       }
       }
     };
     };

+ 6 - 4
reflex/app.py

@@ -69,9 +69,11 @@ from reflex.state import (
     State,
     State,
     StateManager,
     StateManager,
     StateUpdate,
     StateUpdate,
+    _substate_key,
     code_uses_state_contexts,
     code_uses_state_contexts,
 )
 )
 from reflex.utils import console, exceptions, format, prerequisites, types
 from reflex.utils import console, exceptions, format, prerequisites, types
+from reflex.utils.exec import is_testing_env
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
 
 
 # Define custom types.
 # Define custom types.
@@ -159,10 +161,9 @@ class App(Base):
             )
             )
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
         state_subclasses = BaseState.__subclasses__()
         state_subclasses = BaseState.__subclasses__()
-        is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
 
 
         # Special case to allow test cases have multiple subclasses of rx.BaseState.
         # Special case to allow test cases have multiple subclasses of rx.BaseState.
-        if not is_testing_env:
+        if not is_testing_env():
             # Only one Base State class is allowed.
             # Only one Base State class is allowed.
             if len(state_subclasses) > 1:
             if len(state_subclasses) > 1:
                 raise ValueError(
                 raise ValueError(
@@ -176,7 +177,8 @@ class App(Base):
                     deprecation_version="0.3.5",
                     deprecation_version="0.3.5",
                     removal_version="0.5.0",
                     removal_version="0.5.0",
                 )
                 )
-            if len(State.class_subclasses) > 0:
+            # 2 substates are built-in and not considered when determining if app is stateless.
+            if len(State.class_subclasses) > 2:
                 self.state = State
                 self.state = State
         # Get the config
         # Get the config
         config = get_config()
         config = get_config()
@@ -1002,7 +1004,7 @@ def upload(app: App):
             )
             )
 
 
         # Get the state for the session.
         # Get the state for the session.
-        substate_token = token + "_" + handler.rpartition(".")[0]
+        substate_token = _substate_key(token, handler.rpartition(".")[0])
         state = await app.state_manager.get_state(substate_token)
         state = await app.state_manager.get_state(substate_token)
 
 
         # get the current session ID
         # get the current session ID

+ 2 - 2
reflex/compiler/utils.py

@@ -138,12 +138,12 @@ def compile_state(state: Type[BaseState]) -> dict:
         A dictionary of the compiled state.
         A dictionary of the compiled state.
     """
     """
     try:
     try:
-        initial_state = state().dict(initial=True)
+        initial_state = state(_reflex_internal_init=True).dict(initial=True)
     except Exception as e:
     except Exception as e:
         console.warn(
         console.warn(
             f"Failed to compile initial state with computed vars, excluding them: {e}"
             f"Failed to compile initial state with computed vars, excluding them: {e}"
         )
         )
-        initial_state = state().dict(include_computed=False)
+        initial_state = state(_reflex_internal_init=True).dict(include_computed=False)
     return format.format_state(initial_state)
     return format.format_state(initial_state)
 
 
 
 

+ 2 - 2
reflex/constants/compiler.py

@@ -59,9 +59,9 @@ class CompileVars(SimpleNamespace):
     # The name of the function for converting a dict to an event.
     # The name of the function for converting a dict to an event.
     TO_EVENT = "Event"
     TO_EVENT = "Event"
     # The name of the internal on_load event.
     # The name of the internal on_load event.
-    ON_LOAD_INTERNAL = "on_load_internal"
+    ON_LOAD_INTERNAL = "on_load_internal_state.on_load_internal"
     # The name of the internal event to update generic state vars.
     # The name of the internal event to update generic state vars.
-    UPDATE_VARS_INTERNAL = "update_vars_internal"
+    UPDATE_VARS_INTERNAL = "update_vars_internal_state.update_vars_internal"
 
 
 
 
 class PageNames(SimpleNamespace):
 class PageNames(SimpleNamespace):

+ 524 - 105
reflex/state.py

@@ -8,7 +8,6 @@ import copy
 import functools
 import functools
 import inspect
 import inspect
 import json
 import json
-import os
 import traceback
 import traceback
 import urllib.parse
 import urllib.parse
 import uuid
 import uuid
@@ -45,6 +44,7 @@ from reflex.event import (
 )
 )
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
+from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import SerializedType, serialize, serializer
 from reflex.utils.serializers import SerializedType, serialize, serializer
 from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 
 
@@ -151,9 +151,45 @@ RESERVED_BACKEND_VAR_NAMES = {
     "_substate_var_dependencies",
     "_substate_var_dependencies",
     "_always_dirty_computed_vars",
     "_always_dirty_computed_vars",
     "_always_dirty_substates",
     "_always_dirty_substates",
+    "_was_touched",
 }
 }
 
 
 
 
+def _substate_key(
+    token: str,
+    state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
+) -> str:
+    """Get the substate key.
+
+    Args:
+        token: The token of the state.
+        state_cls_or_name: The state class/instance or name or sequence of name parts.
+
+    Returns:
+        The substate key.
+    """
+    if isinstance(state_cls_or_name, BaseState) or (
+        isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
+    ):
+        state_cls_or_name = state_cls_or_name.get_full_name()
+    elif isinstance(state_cls_or_name, (list, tuple)):
+        state_cls_or_name = ".".join(state_cls_or_name)
+    return f"{token}_{state_cls_or_name}"
+
+
+def _split_substate_key(substate_key: str) -> tuple[str, str]:
+    """Split the substate key into token and state name.
+
+    Args:
+        substate_key: The substate key.
+
+    Returns:
+        Tuple of token and state name.
+    """
+    token, _, state_name = substate_key.partition("_")
+    return token, state_name
+
+
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
     """The state of the app."""
 
 
@@ -214,29 +250,46 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # The router data for the current page
     # The router data for the current page
     router: RouterData = RouterData()
     router: RouterData = RouterData()
 
 
+    # Whether the state has ever been touched since instantiation.
+    _was_touched: bool = False
+
     def __init__(
     def __init__(
         self,
         self,
         *args,
         *args,
         parent_state: BaseState | None = None,
         parent_state: BaseState | None = None,
         init_substates: bool = True,
         init_substates: bool = True,
+        _reflex_internal_init: bool = False,
         **kwargs,
         **kwargs,
     ):
     ):
         """Initialize the state.
         """Initialize the state.
 
 
+        DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead.
+
         Args:
         Args:
             *args: The args to pass to the Pydantic init method.
             *args: The args to pass to the Pydantic init method.
             parent_state: The parent state.
             parent_state: The parent state.
             init_substates: Whether to initialize the substates in this instance.
             init_substates: Whether to initialize the substates in this instance.
+            _reflex_internal_init: A flag to indicate that the state is being initialized by the framework.
             **kwargs: The kwargs to pass to the Pydantic init method.
             **kwargs: The kwargs to pass to the Pydantic init method.
 
 
+        Raises:
+            RuntimeError: If the state is instantiated directly by end user.
         """
         """
+        if not _reflex_internal_init and not is_testing_env():
+            raise RuntimeError(
+                "State classes should not be instantiated directly in a Reflex app. "
+                "See https://reflex.dev/docs/state for further information."
+            )
         kwargs["parent_state"] = parent_state
         kwargs["parent_state"] = parent_state
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
 
 
         # Setup the substates (for memory state manager only).
         # Setup the substates (for memory state manager only).
         if init_substates:
         if init_substates:
             for substate in self.get_substates():
             for substate in self.get_substates():
-                self.substates[substate.get_name()] = substate(parent_state=self)
+                self.substates[substate.get_name()] = substate(
+                    parent_state=self,
+                    _reflex_internal_init=True,
+                )
         # Convert the event handlers to functions.
         # Convert the event handlers to functions.
         self._init_event_handlers()
         self._init_event_handlers()
 
 
@@ -287,7 +340,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Raises:
         Raises:
             ValueError: If a substate class shadows another.
             ValueError: If a substate class shadows another.
         """
         """
-        is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
         super().__init_subclass__(**kwargs)
         super().__init_subclass__(**kwargs)
         # Event handlers should not shadow builtin state methods.
         # Event handlers should not shadow builtin state methods.
         cls._check_overridden_methods()
         cls._check_overridden_methods()
@@ -295,6 +347,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Reset subclass tracking for this class.
         # Reset subclass tracking for this class.
         cls.class_subclasses = set()
         cls.class_subclasses = set()
 
 
+        # Reset dirty substate tracking for this class.
+        cls._always_dirty_substates = set()
+
         # Get the parent vars.
         # Get the parent vars.
         parent_state = cls.get_parent_state()
         parent_state = cls.get_parent_state()
         if parent_state is not None:
         if parent_state is not None:
@@ -303,7 +358,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
             # Check if another substate class with the same name has already been defined.
             # Check if another substate class with the same name has already been defined.
             if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
             if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
-                if is_testing_env:
+                if is_testing_env():
                     # Clear existing subclass with same name when app is reloaded via
                     # Clear existing subclass with same name when app is reloaded via
                     # utils.prerequisites.get_app(reload=True)
                     # utils.prerequisites.get_app(reload=True)
                     parent_state.class_subclasses = set(
                     parent_state.class_subclasses = set(
@@ -325,6 +380,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             name: value
             name: value
             for name, value in cls.__dict__.items()
             for name, value in cls.__dict__.items()
             if types.is_backend_variable(name, cls)
             if types.is_backend_variable(name, cls)
+            and name not in RESERVED_BACKEND_VAR_NAMES
             and name not in cls.inherited_backend_vars
             and name not in cls.inherited_backend_vars
             and not isinstance(value, FunctionType)
             and not isinstance(value, FunctionType)
         }
         }
@@ -484,7 +540,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         )
         )
 
 
         # Any substate containing a ComputedVar with cache=False always needs to be recomputed
         # Any substate containing a ComputedVar with cache=False always needs to be recomputed
-        cls._always_dirty_substates = set()
         if cls._always_dirty_computed_vars:
         if cls._always_dirty_computed_vars:
             # Tell parent classes that this substate has always dirty computed vars
             # Tell parent classes that this substate has always dirty computed vars
             state_name = cls.get_name()
             state_name = cls.get_name()
@@ -923,8 +978,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             **super().__getattribute__("inherited_vars"),
             **super().__getattribute__("inherited_vars"),
             **super().__getattribute__("inherited_backend_vars"),
             **super().__getattribute__("inherited_backend_vars"),
         }
         }
-        if name in inherited_vars:
-            return getattr(super().__getattribute__("parent_state"), name)
+
+        # For now, handle router_data updates as a special case.
+        if name in inherited_vars or name == constants.ROUTER_DATA:
+            parent_state = super().__getattribute__("parent_state")
+            if parent_state is not None:
+                return getattr(parent_state, name)
 
 
         backend_vars = super().__getattribute__("_backend_vars")
         backend_vars = super().__getattribute__("_backend_vars")
         if name in backend_vars:
         if name in backend_vars:
@@ -980,9 +1039,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         if name == constants.ROUTER_DATA:
         if name == constants.ROUTER_DATA:
             self.dirty_vars.add(name)
             self.dirty_vars.add(name)
             self._mark_dirty()
             self._mark_dirty()
-            # propagate router_data updates down the state tree
-            for substate in self.substates.values():
-                setattr(substate, name, value)
 
 
     def reset(self):
     def reset(self):
         """Reset all the base vars to their default values."""
         """Reset all the base vars to their default values."""
@@ -1036,6 +1092,170 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             raise ValueError(f"Invalid path: {path}")
             raise ValueError(f"Invalid path: {path}")
         return self.substates[path[0]].get_substate(path[1:])
         return self.substates[path[0]].get_substate(path[1:])
 
 
+    @classmethod
+    def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
+        """Find the name of the nearest common ancestor shared by this and the other state.
+
+        Args:
+            other: The other state.
+
+        Returns:
+            Full name of the nearest common ancestor.
+        """
+        common_ancestor_parts = []
+        for part1, part2 in zip(
+            cls.get_full_name().split("."),
+            other.get_full_name().split("."),
+        ):
+            if part1 != part2:
+                break
+            common_ancestor_parts.append(part1)
+        return ".".join(common_ancestor_parts)
+
+    @classmethod
+    def _determine_missing_parent_states(
+        cls, target_state_cls: Type[BaseState]
+    ) -> tuple[str, list[str]]:
+        """Determine the missing parent states between the target_state_cls and common ancestor of this state.
+
+        Args:
+            target_state_cls: The class of the state to find missing parent states for.
+
+        Returns:
+            The name of the common ancestor and the list of missing parent states.
+        """
+        common_ancestor_name = cls._get_common_ancestor(target_state_cls)
+        common_ancestor_parts = common_ancestor_name.split(".")
+        target_state_parts = tuple(target_state_cls.get_full_name().split("."))
+        relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
+
+        # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
+        fetch_parent_states = [common_ancestor_name]
+        for ix, relative_parent_state_name in enumerate(relative_target_state_parts):
+            fetch_parent_states.append(
+                ".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name])
+            )
+
+        return common_ancestor_name, fetch_parent_states[1:-1]
+
+    def _get_parent_states(self) -> list[tuple[str, BaseState]]:
+        """Get all parent state instances up to the root of the state tree.
+
+        Returns:
+            A list of tuples containing the name and the instance of each parent state.
+        """
+        parent_states_with_name = []
+        parent_state = self
+        while parent_state.parent_state is not None:
+            parent_state = parent_state.parent_state
+            parent_states_with_name.append((parent_state.get_full_name(), parent_state))
+        return parent_states_with_name
+
+    async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
+        """Populate substates in the tree between the target_state_cls and common ancestor of this state.
+
+        Args:
+            target_state_cls: The class of the state to populate parent states for.
+
+        Returns:
+            The parent state instance of target_state_cls.
+
+        Raises:
+            RuntimeError: If redis is not used in this backend process.
+        """
+        state_manager = get_state_manager()
+        if not isinstance(state_manager, StateManagerRedis):
+            raise RuntimeError(
+                f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
+                "(All states should already be available -- this is likely a bug).",
+            )
+
+        # Find the missing parent states up to the common ancestor.
+        (
+            common_ancestor_name,
+            missing_parent_states,
+        ) = self._determine_missing_parent_states(target_state_cls)
+
+        # Fetch all missing parent states and link them up to the common ancestor.
+        parent_states_by_name = dict(self._get_parent_states())
+        parent_state = parent_states_by_name[common_ancestor_name]
+        for parent_state_name in missing_parent_states:
+            parent_state = await state_manager.get_state(
+                token=_substate_key(
+                    self.router.session.client_token, parent_state_name
+                ),
+                top_level=False,
+                get_substates=False,
+                parent_state=parent_state,
+            )
+
+        # Return the direct parent of target_state_cls for subsequent linking.
+        return parent_state
+
+    def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
+        """Get a state instance from the cache.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The instance of state_cls associated with this state's client_token.
+        """
+        if self.parent_state is None:
+            root_state = self
+        else:
+            root_state = self._get_parent_states()[-1][1]
+        return root_state.get_substate(state_cls.get_full_name().split("."))
+
+    async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
+        """Get a state instance from redis.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The instance of state_cls associated with this state's client_token.
+
+        Raises:
+            RuntimeError: If redis is not used in this backend process.
+        """
+        # Fetch all missing parent states from redis.
+        parent_state_of_state_cls = await self._populate_parent_states(state_cls)
+
+        # Then get the target state and all its substates.
+        state_manager = get_state_manager()
+        if not isinstance(state_manager, StateManagerRedis):
+            raise RuntimeError(
+                f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
+                "(All states should already be available -- this is likely a bug).",
+            )
+        return await state_manager.get_state(
+            token=_substate_key(self.router.session.client_token, state_cls),
+            top_level=False,
+            get_substates=True,
+            parent_state=parent_state_of_state_cls,
+        )
+
+    async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
+        """Get an instance of the state associated with this token.
+
+        Allows for arbitrary access to sibling states from within an event handler.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The instance of state_cls associated with this state's client_token.
+        """
+        # Fast case - if this state instance is already cached, get_substate from root state.
+        try:
+            return self._get_state_from_cache(state_cls)
+        except ValueError:
+            pass
+
+        # Slow case - fetch missing parent states from redis.
+        return await self._get_state_from_redis(state_cls)
+
     def _get_event_handler(
     def _get_event_handler(
         self, event: Event
         self, event: Event
     ) -> tuple[BaseState | StateProxy, EventHandler]:
     ) -> tuple[BaseState | StateProxy, EventHandler]:
@@ -1238,6 +1458,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             for cvar in self._computed_var_dependencies[dirty_var]
             for cvar in self._computed_var_dependencies[dirty_var]
         )
         )
 
 
+    @classmethod
+    def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
+        """Determine substates which could be affected by dirty vars in this state.
+
+        Returns:
+            Set of State classes that may need to be fetched to recalc computed vars.
+        """
+        # _always_dirty_substates need to be fetched to recalc computed vars.
+        fetch_substates = set(
+            cls.get_class_substate(tuple(substate_name.split(".")))
+            for substate_name in cls._always_dirty_substates
+        )
+        # Substates with cached vars also need to be fetched.
+        for dependent_substates in cls._substate_var_dependencies.values():
+            fetch_substates.update(
+                set(
+                    cls.get_class_substate(tuple(substate_name.split(".")))
+                    for substate_name in dependent_substates
+                )
+            )
+        return fetch_substates
+
     def get_delta(self) -> Delta:
     def get_delta(self) -> Delta:
         """Get the delta for the state.
         """Get the delta for the state.
 
 
@@ -1269,8 +1511,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Recursively find the substate deltas.
         # Recursively find the substate deltas.
         substates = self.substates
         substates = self.substates
         for substate in self.dirty_substates.union(self._always_dirty_substates):
         for substate in self.dirty_substates.union(self._always_dirty_substates):
-            if substate not in substates:
-                continue  # substate not loaded at this time, no delta
             delta.update(substates[substate].get_delta())
             delta.update(substates[substate].get_delta())
 
 
         # Format the delta.
         # Format the delta.
@@ -1292,20 +1532,45 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # have to mark computed vars dirty to allow access to newly computed
         # have to mark computed vars dirty to allow access to newly computed
         # values within the same ComputedVar function
         # values within the same ComputedVar function
         self._mark_dirty_computed_vars()
         self._mark_dirty_computed_vars()
+        self._mark_dirty_substates()
 
 
-        # Propagate dirty var / computed var status into substates
+    def _mark_dirty_substates(self):
+        """Propagate dirty var / computed var status into substates."""
         substates = self.substates
         substates = self.substates
         for var in self.dirty_vars:
         for var in self.dirty_vars:
             for substate_name in self._substate_var_dependencies[var]:
             for substate_name in self._substate_var_dependencies[var]:
                 self.dirty_substates.add(substate_name)
                 self.dirty_substates.add(substate_name)
-                if substate_name not in substates:
-                    continue
                 substate = substates[substate_name]
                 substate = substates[substate_name]
                 substate.dirty_vars.add(var)
                 substate.dirty_vars.add(var)
                 substate._mark_dirty()
                 substate._mark_dirty()
 
 
+    def _update_was_touched(self):
+        """Update the _was_touched flag based on dirty_vars."""
+        if self.dirty_vars and not self._was_touched:
+            for var in self.dirty_vars:
+                if var in self.base_vars or var in self._backend_vars:
+                    self._was_touched = True
+                    break
+
+    def _get_was_touched(self) -> bool:
+        """Check current dirty_vars and flag to determine if state instance was modified.
+
+        If any dirty vars belong to this state, mark _was_touched.
+
+        This flag determines whether this state instance should be persisted to redis.
+
+        Returns:
+            Whether this state instance was ever modified.
+        """
+        # Ensure the flag is up to date based on the current dirty_vars
+        self._update_was_touched()
+        return self._was_touched
+
     def _clean(self):
     def _clean(self):
         """Reset the dirty vars."""
         """Reset the dirty vars."""
+        # Update touched status before cleaning dirty_vars.
+        self._update_was_touched()
+
         # Recursively clean the substates.
         # Recursively clean the substates.
         for substate in self.dirty_substates:
         for substate in self.dirty_substates:
             if substate not in self.substates:
             if substate not in self.substates:
@@ -1422,6 +1687,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         state["__dict__"] = state["__dict__"].copy()
         state["__dict__"] = state["__dict__"].copy()
         state["__dict__"]["parent_state"] = None
         state["__dict__"]["parent_state"] = None
         state["__dict__"]["substates"] = {}
         state["__dict__"]["substates"] = {}
+        state["__dict__"].pop("_was_touched", None)
         return state
         return state
 
 
 
 
@@ -1431,6 +1697,35 @@ class State(BaseState):
     # The hydrated bool.
     # The hydrated bool.
     is_hydrated: bool = False
     is_hydrated: bool = False
 
 
+
+class UpdateVarsInternalState(State):
+    """Substate for handling internal state var updates."""
+
+    async def update_vars_internal(self, vars: dict[str, Any]) -> None:
+        """Apply updates to fully qualified state vars.
+
+        The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
+        and each value will be set on the appropriate substate instance.
+
+        This function is primarily used to apply cookie and local storage
+        updates from the frontend to the appropriate substate.
+
+        Args:
+            vars: The fully qualified vars and values to update.
+        """
+        for var, value in vars.items():
+            state_name, _, var_name = var.rpartition(".")
+            var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
+            var_state = await self.get_state(var_state_cls)
+            setattr(var_state, var_name, value)
+
+
+class OnLoadInternalState(State):
+    """Substate for handling on_load event enumeration.
+
+    This is a separate substate to avoid deserializing the entire state tree for every page navigation.
+    """
+
     def on_load_internal(self) -> list[Event | EventSpec] | None:
     def on_load_internal(self) -> list[Event | EventSpec] | None:
         """Queue on_load handlers for the current page.
         """Queue on_load handlers for the current page.
 
 
@@ -1442,6 +1737,9 @@ class State(BaseState):
         load_events = app.get_load_events(self.router.page.path)
         load_events = app.get_load_events(self.router.page.path)
         if not load_events and self.is_hydrated:
         if not load_events and self.is_hydrated:
             return  # Fast path for page-to-page navigation
             return  # Fast path for page-to-page navigation
+        if not load_events:
+            self.is_hydrated = True
+            return  # Fast path for initial hydrate with no on_load events defined.
         self.is_hydrated = False
         self.is_hydrated = False
         return [
         return [
             *fix_events(
             *fix_events(
@@ -1449,26 +1747,9 @@ class State(BaseState):
                 self.router.session.client_token,
                 self.router.session.client_token,
                 router_data=self.router_data,
                 router_data=self.router_data,
             ),
             ),
-            type(self).set_is_hydrated(True),  # type: ignore
+            State.set_is_hydrated(True),  # type: ignore
         ]
         ]
 
 
-    def update_vars_internal(self, vars: dict[str, Any]) -> None:
-        """Apply updates to fully qualified state vars.
-
-        The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
-        and each value will be set on the appropriate substate instance.
-
-        This function is primarily used to apply cookie and local storage
-        updates from the frontend to the appropriate substate.
-
-        Args:
-            vars: The fully qualified vars and values to update.
-        """
-        for var, value in vars.items():
-            state_name, _, var_name = var.rpartition(".")
-            var_state = self.get_substate(state_name.split("."))
-            setattr(var_state, var_name, value)
-
 
 
 class StateProxy(wrapt.ObjectProxy):
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.
     """Proxy of a state instance to control mutability of vars for a background task.
@@ -1522,9 +1803,10 @@ class StateProxy(wrapt.ObjectProxy):
             This StateProxy instance in mutable mode.
             This StateProxy instance in mutable mode.
         """
         """
         self._self_actx = self._self_app.modify_state(
         self._self_actx = self._self_app.modify_state(
-            self.__wrapped__.router.session.client_token
-            + "_"
-            + ".".join(self._self_substate_path)
+            token=_substate_key(
+                self.__wrapped__.router.session.client_token,
+                self._self_substate_path,
+            )
         )
         )
         mutable_state = await self._self_actx.__aenter__()
         mutable_state = await self._self_actx.__aenter__()
         super().__setattr__(
         super().__setattr__(
@@ -1574,7 +1856,15 @@ class StateProxy(wrapt.ObjectProxy):
 
 
         Returns:
         Returns:
             The value of the attribute.
             The value of the attribute.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
         """
         """
+        if name in ["substates", "parent_state"] and not self._self_mutable:
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
         value = super().__getattr__(name)
         value = super().__getattr__(name)
         if not name.startswith("_self_") and isinstance(value, MutableProxy):
         if not name.startswith("_self_") and isinstance(value, MutableProxy):
             # ensure mutations to these containers are blocked unless proxy is _mutable
             # ensure mutations to these containers are blocked unless proxy is _mutable
@@ -1622,6 +1912,60 @@ class StateProxy(wrapt.ObjectProxy):
             "manager. Use `async with self` to modify state."
             "manager. Use `async with self` to modify state."
         )
         )
 
 
+    def get_substate(self, path: Sequence[str]) -> BaseState:
+        """Only allow substate access with lock held.
+
+        Args:
+            path: The path to the substate.
+
+        Returns:
+            The substate.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if not self._self_mutable:
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        return self.__wrapped__.get_substate(path)
+
+    async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
+        """Get an instance of the state associated with this token.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The state.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if not self._self_mutable:
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        return await self.__wrapped__.get_state(state_cls)
+
+    def _as_state_update(self, *args, **kwargs) -> StateUpdate:
+        """Temporarily allow mutability to access parent_state.
+
+        Args:
+            *args: The args to pass to the underlying state instance.
+            **kwargs: The kwargs to pass to the underlying state instance.
+
+        Returns:
+            The state update.
+        """
+        self._self_mutable = True
+        try:
+            return self.__wrapped__._as_state_update(*args, **kwargs)
+        finally:
+            self._self_mutable = False
+
 
 
 class StateUpdate(Base):
 class StateUpdate(Base):
     """A state update sent to the frontend."""
     """A state update sent to the frontend."""
@@ -1722,9 +2066,9 @@ class StateManagerMemory(StateManager):
             The state for the token.
             The state for the token.
         """
         """
         # Memory state manager ignores the substate suffix and always returns the top-level state.
         # Memory state manager ignores the substate suffix and always returns the top-level state.
-        token = token.partition("_")[0]
+        token = _split_substate_key(token)[0]
         if token not in self.states:
         if token not in self.states:
-            self.states[token] = self.state()
+            self.states[token] = self.state(_reflex_internal_init=True)
         return self.states[token]
         return self.states[token]
 
 
     async def set_state(self, token: str, state: BaseState):
     async def set_state(self, token: str, state: BaseState):
@@ -1747,7 +2091,7 @@ class StateManagerMemory(StateManager):
             The state for the token.
             The state for the token.
         """
         """
         # Memory state manager ignores the substate suffix and always returns the top-level state.
         # Memory state manager ignores the substate suffix and always returns the top-level state.
-        token = token.partition("_")[0]
+        token = _split_substate_key(token)[0]
         if token not in self._states_locks:
         if token not in self._states_locks:
             async with self._state_manager_lock:
             async with self._state_manager_lock:
                 if token not in self._states_locks:
                 if token not in self._states_locks:
@@ -1787,6 +2131,81 @@ class StateManagerRedis(StateManager):
         b"evicted",
         b"evicted",
     }
     }
 
 
+    def _get_root_state(self, state: BaseState) -> BaseState:
+        """Chase parent_state pointers to find an instance of the top-level state.
+
+        Args:
+            state: The state to start from.
+
+        Returns:
+            An instance of the top-level state (self.state).
+        """
+        while type(state) != self.state and state.parent_state is not None:
+            state = state.parent_state
+        return state
+
+    async def _get_parent_state(self, token: str) -> BaseState | None:
+        """Get the parent state for the state requested in the token.
+
+        Args:
+            token: The token to get the state for (_substate_key).
+
+        Returns:
+            The parent state for the state requested by the token or None if there is no such parent.
+        """
+        parent_state = None
+        client_token, state_path = _split_substate_key(token)
+        parent_state_name = state_path.rpartition(".")[0]
+        if parent_state_name:
+            # Retrieve the parent state to populate event handlers onto this substate.
+            parent_state = await self.get_state(
+                token=_substate_key(client_token, parent_state_name),
+                top_level=False,
+                get_substates=False,
+            )
+        return parent_state
+
+    async def _populate_substates(
+        self,
+        token: str,
+        state: BaseState,
+        all_substates: bool = False,
+    ):
+        """Fetch and link substates for the given state instance.
+
+        There is no return value; the side-effect is that `state` will have `substates` populated,
+        and each substate will have its `parent_state` set to `state`.
+
+        Args:
+            token: The token to get the state for.
+            state: The state instance to populate substates for.
+            all_substates: Whether to fetch all substates or just required substates.
+        """
+        client_token, _ = _split_substate_key(token)
+
+        if all_substates:
+            # All substates are requested.
+            fetch_substates = state.get_substates()
+        else:
+            # Only _potentially_dirty_substates need to be fetched to recalc computed vars.
+            fetch_substates = state._potentially_dirty_substates()
+
+        tasks = {}
+        # Retrieve the necessary substates from redis.
+        for substate_cls in fetch_substates:
+            substate_name = substate_cls.get_name()
+            tasks[substate_name] = asyncio.create_task(
+                self.get_state(
+                    token=_substate_key(client_token, substate_cls),
+                    top_level=False,
+                    get_substates=all_substates,
+                    parent_state=state,
+                )
+            )
+
+        for substate_name, substate_task in tasks.items():
+            state.substates[substate_name] = await substate_task
+
     async def get_state(
     async def get_state(
         self,
         self,
         token: str,
         token: str,
@@ -1798,8 +2217,8 @@ class StateManagerRedis(StateManager):
 
 
         Args:
         Args:
             token: The token to get the state for.
             token: The token to get the state for.
-            top_level: If true, return an instance of the top-level state.
-            get_substates: If true, also retrieve substates
+            top_level: If true, return an instance of the top-level state (self.state).
+            get_substates: If true, also retrieve substates.
             parent_state: If provided, use this parent_state instead of getting it from redis.
             parent_state: If provided, use this parent_state instead of getting it from redis.
 
 
         Returns:
         Returns:
@@ -1809,7 +2228,7 @@ class StateManagerRedis(StateManager):
             RuntimeError: when the state_cls is not specified in the token
             RuntimeError: when the state_cls is not specified in the token
         """
         """
         # Split the actual token from the fully qualified substate name.
         # Split the actual token from the fully qualified substate name.
-        client_token, _, state_path = token.partition("_")
+        _, state_path = _split_substate_key(token)
         if state_path:
         if state_path:
             # Get the State class associated with the given path.
             # Get the State class associated with the given path.
             state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
             state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
@@ -1825,66 +2244,49 @@ class StateManagerRedis(StateManager):
             # Deserialize the substate.
             # Deserialize the substate.
             state = cloudpickle.loads(redis_state)
             state = cloudpickle.loads(redis_state)
 
 
-            # Populate parent and substates if requested.
+            # Populate parent state if missing and requested.
             if parent_state is None:
             if parent_state is None:
-                # Retrieve the parent state from redis.
-                parent_state_name = state_path.rpartition(".")[0]
-                if parent_state_name:
-                    parent_state_key = token.rpartition(".")[0]
-                    parent_state = await self.get_state(
-                        parent_state_key, top_level=False, get_substates=False
-                    )
+                parent_state = await self._get_parent_state(token)
             # Set up Bidirectional linkage between this state and its parent.
             # Set up Bidirectional linkage between this state and its parent.
             if parent_state is not None:
             if parent_state is not None:
                 parent_state.substates[state.get_name()] = state
                 parent_state.substates[state.get_name()] = state
                 state.parent_state = parent_state
                 state.parent_state = parent_state
-            if get_substates:
-                # Retrieve all substates from redis.
-                for substate_cls in state_cls.get_substates():
-                    substate_name = substate_cls.get_name()
-                    substate_key = token + "." + substate_name
-                    state.substates[substate_name] = await self.get_state(
-                        substate_key, top_level=False, parent_state=state
-                    )
+            # Populate substates if requested.
+            await self._populate_substates(token, state, all_substates=get_substates)
+
             # To retain compatibility with previous implementation, by default, we return
             # To retain compatibility with previous implementation, by default, we return
             # the top-level state by chasing `parent_state` pointers up the tree.
             # the top-level state by chasing `parent_state` pointers up the tree.
             if top_level:
             if top_level:
-                while type(state) != self.state and state.parent_state is not None:
-                    state = state.parent_state
+                return self._get_root_state(state)
             return state
             return state
 
 
-        # Key didn't exist so we have to create a new entry for this token.
+        # TODO: dedupe the following logic with the above block
+        # Key didn't exist so we have to create a new instance for this token.
         if parent_state is None:
         if parent_state is None:
-            parent_state_name = state_path.rpartition(".")[0]
-            if parent_state_name:
-                # Retrieve the parent state to populate event handlers onto this substate.
-                parent_state_key = client_token + "_" + parent_state_name
-                parent_state = await self.get_state(
-                    parent_state_key, top_level=False, get_substates=False
-                )
-        # Persist the new state class to redis.
-        await self.set_state(
-            token,
-            state_cls(
-                parent_state=parent_state,
-                init_substates=False,
-            ),
-        )
-        # After creating the state key, recursively call `get_state` to populate substates.
-        return await self.get_state(
-            token,
-            top_level=top_level,
-            get_substates=get_substates,
+            parent_state = await self._get_parent_state(token)
+        # Instantiate the new state class (but don't persist it yet).
+        state = state_cls(
             parent_state=parent_state,
             parent_state=parent_state,
+            init_substates=False,
+            _reflex_internal_init=True,
         )
         )
+        # Set up Bidirectional linkage between this state and its parent.
+        if parent_state is not None:
+            parent_state.substates[state.get_name()] = state
+            state.parent_state = parent_state
+        # Populate substates for the newly created state.
+        await self._populate_substates(token, state, all_substates=get_substates)
+        # To retain compatibility with previous implementation, by default, we return
+        # the top-level state by chasing `parent_state` pointers up the tree.
+        if top_level:
+            return self._get_root_state(state)
+        return state
 
 
     async def set_state(
     async def set_state(
         self,
         self,
         token: str,
         token: str,
         state: BaseState,
         state: BaseState,
         lock_id: bytes | None = None,
         lock_id: bytes | None = None,
-        set_substates: bool = True,
-        set_parent_state: bool = True,
     ):
     ):
         """Set the state for a token.
         """Set the state for a token.
 
 
@@ -1892,11 +2294,10 @@ class StateManagerRedis(StateManager):
             token: The token to set the state for.
             token: The token to set the state for.
             state: The state to set.
             state: The state to set.
             lock_id: If provided, the lock_key must be set to this value to set the state.
             lock_id: If provided, the lock_key must be set to this value to set the state.
-            set_substates: If True, write substates to redis
-            set_parent_state: If True, write parent state to redis
 
 
         Raises:
         Raises:
             LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
             LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
+            RuntimeError: If the state instance doesn't match the state name in the token.
         """
         """
         # Check that we're holding the lock.
         # Check that we're holding the lock.
         if (
         if (
@@ -1908,28 +2309,36 @@ class StateManagerRedis(StateManager):
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                 "or use `@rx.background` decorator for long-running tasks."
                 "or use `@rx.background` decorator for long-running tasks."
             )
             )
-        # Find the substate associated with the token.
-        state_path = token.partition("_")[2]
-        if state_path and state.get_full_name() != state_path:
-            state = state.get_substate(tuple(state_path.split(".")))
-        # Persist the parent state separately, if requested.
-        if state.parent_state is not None and set_parent_state:
-            parent_state_key = token.rpartition(".")[0]
-            await self.set_state(
-                parent_state_key,
-                state.parent_state,
-                lock_id=lock_id,
-                set_substates=False,
+        client_token, substate_name = _split_substate_key(token)
+        # If the substate name on the token doesn't match the instance name, it cannot have a parent.
+        if state.parent_state is not None and state.get_full_name() != substate_name:
+            raise RuntimeError(
+                f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
             )
             )
-        # Persist the substates separately, if requested.
-        if set_substates:
-            for substate_name, substate in state.substates.items():
-                substate_key = token + "." + substate_name
-                await self.set_state(
-                    substate_key, substate, lock_id=lock_id, set_parent_state=False
+
+        # Recursively set_state on all known substates.
+        tasks = []
+        for substate in state.substates.values():
+            tasks.append(
+                asyncio.create_task(
+                    self.set_state(
+                        token=_substate_key(client_token, substate),
+                        state=substate,
+                        lock_id=lock_id,
+                    )
                 )
                 )
+            )
         # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
         # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
-        await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
+        if state._get_was_touched():
+            await self.redis.set(
+                _substate_key(client_token, state),
+                cloudpickle.dumps(state),
+                ex=self.token_expiration,
+            )
+
+        # Wait for substates to be persisted.
+        for t in tasks:
+            await t
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
@@ -1957,7 +2366,7 @@ class StateManagerRedis(StateManager):
             The redis lock key for the token.
             The redis lock key for the token.
         """
         """
         # All substates share the same lock domain, so ignore any substate path suffix.
         # All substates share the same lock domain, so ignore any substate path suffix.
-        client_token = token.partition("_")[0]
+        client_token = _split_substate_key(token)[0]
         return f"{client_token}_lock".encode()
         return f"{client_token}_lock".encode()
 
 
     async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
     async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
@@ -2052,6 +2461,16 @@ class StateManagerRedis(StateManager):
         await self.redis.close(close_connection_pool=True)
         await self.redis.close(close_connection_pool=True)
 
 
 
 
+def get_state_manager() -> StateManager:
+    """Get the state manager for the app that is currently running.
+
+    Returns:
+        The state manager.
+    """
+    app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+    return app.state_manager
+
+
 class ClientStorageBase:
 class ClientStorageBase:
     """Base class for client-side storage."""
     """Base class for client-side storage."""
 
 

+ 6 - 0
reflex/testing.py

@@ -70,6 +70,10 @@ else:
     FRONTEND_POPEN_ARGS["start_new_session"] = True
     FRONTEND_POPEN_ARGS["start_new_session"] = True
 
 
 
 
+# Save a copy of internal substates to reset after each test.
+INTERNAL_STATES = State.class_subclasses.copy()
+
+
 # borrowed from py3.11
 # borrowed from py3.11
 class chdir(contextlib.AbstractContextManager):
 class chdir(contextlib.AbstractContextManager):
     """Non thread-safe context manager to change the current working directory."""
     """Non thread-safe context manager to change the current working directory."""
@@ -220,6 +224,8 @@ class AppHarness:
             reflex.config.get_config(reload=True)
             reflex.config.get_config(reload=True)
             # reset rx.State subclasses
             # reset rx.State subclasses
             State.class_subclasses.clear()
             State.class_subclasses.clear()
+            State.class_subclasses.update(INTERNAL_STATES)
+            State._always_dirty_substates = set()
             State.get_class_substate.cache_clear()
             State.get_class_substate.cache_clear()
             # Ensure the AppHarness test does not skip State assignment due to running via pytest
             # Ensure the AppHarness test does not skip State assignment due to running via pytest
             os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
             os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)

+ 9 - 0
reflex/utils/exec.py

@@ -285,3 +285,12 @@ def output_system_info():
     console.debug(f"Using package executer at: {prerequisites.get_package_manager()}")  # type: ignore
     console.debug(f"Using package executer at: {prerequisites.get_package_manager()}")  # type: ignore
     if system != "Windows":
     if system != "Windows":
         console.debug(f"Unzip path: {path_ops.which('unzip')}")
         console.debug(f"Unzip path: {path_ops.which('unzip')}")
+
+
+def is_testing_env() -> bool:
+    """Whether the app is running in a testing environment.
+
+    Returns:
+        True if the app is running in under pytest.
+    """
+    return constants.PYTEST_CURRENT_TEST in os.environ

+ 10 - 0
reflex/vars.py

@@ -1875,6 +1875,10 @@ class ComputedVar(Var, property):
 
 
         Returns:
         Returns:
             A set of variable names accessed by the given obj.
             A set of variable names accessed by the given obj.
+
+        Raises:
+            ValueError: if the function references the get_state, parent_state, or substates attributes
+                (cannot track deps in a related state, only implicitly via parent state).
         """
         """
         d = set()
         d = set()
         if obj is None:
         if obj is None:
@@ -1898,6 +1902,8 @@ class ComputedVar(Var, property):
         if self_name is None:
         if self_name is None:
             # cannot reference attributes on self if method takes no args
             # cannot reference attributes on self if method takes no args
             return set()
             return set()
+
+        invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
         self_is_top_of_stack = False
         self_is_top_of_stack = False
         for instruction in dis.get_instructions(obj):
         for instruction in dis.get_instructions(obj):
             if (
             if (
@@ -1916,6 +1922,10 @@ class ComputedVar(Var, property):
                     ref_obj = getattr(objclass, instruction.argval)
                     ref_obj = getattr(objclass, instruction.argval)
                 except Exception:
                 except Exception:
                     ref_obj = None
                     ref_obj = None
+                if instruction.argval in invalid_names:
+                    raise ValueError(
+                        f"Cached var {self._var_full_name} cannot access arbitrary state via `{instruction.argval}`."
+                    )
                 if callable(ref_obj):
                 if callable(ref_obj):
                     # recurse into callable attributes
                     # recurse into callable attributes
                     d.update(
                     d.update(

+ 17 - 10
tests/test_app.py

@@ -29,7 +29,15 @@ from reflex.components.radix.themes.typography.text import Text
 from reflex.event import Event
 from reflex.event import Event
 from reflex.middleware import HydrateMiddleware
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.model import Model
-from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
+from reflex.state import (
+    BaseState,
+    OnLoadInternalState,
+    RouterData,
+    State,
+    StateManagerRedis,
+    StateUpdate,
+    _substate_key,
+)
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format
 from reflex.utils import format
 from reflex.vars import ComputedVar
 from reflex.vars import ComputedVar
@@ -362,7 +370,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str):
     assert app.state == test_state
     assert app.state == test_state
 
 
     # Get a state for a given token.
     # Get a state for a given token.
-    state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}")
+    state = await app.state_manager.get_state(_substate_key(token, test_state))
     assert isinstance(state, test_state)
     assert isinstance(state, test_state)
     assert state.var == 0  # type: ignore
     assert state.var == 0  # type: ignore
 
 
@@ -766,8 +774,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
     # The App state must be the "root" of the state tree
     # The App state must be the "root" of the state tree
     app = App(state=State)
     app = App(state=State)
     app.event_namespace.emit = AsyncMock()  # type: ignore
     app.event_namespace.emit = AsyncMock()  # type: ignore
-    substate_token = f"{token}_{state.get_full_name()}"
-    current_state = await app.state_manager.get_state(substate_token)
+    current_state = await app.state_manager.get_state(_substate_key(token, state))
     data = b"This is binary data"
     data = b"This is binary data"
 
 
     # Create a binary IO object and write data to it
     # Create a binary IO object and write data to it
@@ -796,7 +803,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
             == StateUpdate(delta=delta, events=[], final=True).json() + "\n"
             == StateUpdate(delta=delta, events=[], final=True).json() + "\n"
         )
         )
 
 
-    current_state = await app.state_manager.get_state(substate_token)
+    current_state = await app.state_manager.get_state(_substate_key(token, state))
     state_dict = current_state.dict()[state.get_full_name()]
     state_dict = current_state.dict()[state.get_full_name()]
     assert state_dict["img_list"] == [
     assert state_dict["img_list"] == [
         "image1.jpg",
         "image1.jpg",
@@ -913,7 +920,7 @@ class DynamicState(BaseState):
         # self.side_effect_counter = self.side_effect_counter + 1
         # self.side_effect_counter = self.side_effect_counter + 1
         return self.dynamic
         return self.dynamic
 
 
-    on_load_internal = State.on_load_internal.fn
+    on_load_internal = OnLoadInternalState.on_load_internal.fn
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
@@ -950,7 +957,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     }
     }
     assert constants.ROUTER in app.state()._computed_var_dependencies
     assert constants.ROUTER in app.state()._computed_var_dependencies
 
 
-    substate_token = f"{token}_{DynamicState.get_full_name()}"
+    substate_token = _substate_key(token, DynamicState)
     sid = "mock_sid"
     sid = "mock_sid"
     client_ip = "127.0.0.1"
     client_ip = "127.0.0.1"
     state = await app.state_manager.get_state(substate_token)
     state = await app.state_manager.get_state(substate_token)
@@ -978,7 +985,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     prev_exp_val = ""
     prev_exp_val = ""
     for exp_index, exp_val in enumerate(exp_vals):
     for exp_index, exp_val in enumerate(exp_vals):
         on_load_internal = _event(
         on_load_internal = _event(
-            name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}",
+            name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}",
             val=exp_val,
             val=exp_val,
         )
         )
         exp_router_data = {
         exp_router_data = {
@@ -1013,8 +1020,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                     name="on_load",
                     name="on_load",
                     val=exp_val,
                     val=exp_val,
                 ),
                 ),
-                _dynamic_state_event(
-                    name="set_is_hydrated",
+                _event(
+                    name="state.set_is_hydrated",
                     payload={"value": True},
                     payload={"value": True},
                     val=exp_val,
                     val=exp_val,
                     router_data={},
                     router_data={},

+ 207 - 17
tests/test_state.py

@@ -23,6 +23,7 @@ from reflex.state import (
     ImmutableStateError,
     ImmutableStateError,
     LockExpiredError,
     LockExpiredError,
     MutableProxy,
     MutableProxy,
+    OnLoadInternalState,
     RouterData,
     RouterData,
     State,
     State,
     StateManager,
     StateManager,
@@ -30,6 +31,7 @@ from reflex.state import (
     StateManagerRedis,
     StateManagerRedis,
     StateProxy,
     StateProxy,
     StateUpdate,
     StateUpdate,
+    _substate_key,
 )
 )
 from reflex.utils import prerequisites, types
 from reflex.utils import prerequisites, types
 from reflex.utils.format import json_dumps
 from reflex.utils.format import json_dumps
@@ -139,6 +141,12 @@ class ChildState2(TestState):
     value: str
     value: str
 
 
 
 
+class ChildState3(TestState):
+    """A child state fixture."""
+
+    value: str
+
+
 class GrandchildState(ChildState):
 class GrandchildState(ChildState):
     """A grandchild state fixture."""
     """A grandchild state fixture."""
 
 
@@ -149,6 +157,32 @@ class GrandchildState(ChildState):
         pass
         pass
 
 
 
 
+class GrandchildState2(ChildState2):
+    """A grandchild state fixture."""
+
+    @rx.cached_var
+    def cached(self) -> str:
+        """A cached var.
+
+        Returns:
+            The value.
+        """
+        return self.value
+
+
+class GrandchildState3(ChildState3):
+    """A great grandchild state fixture."""
+
+    @rx.var
+    def computed(self) -> str:
+        """A computed var.
+
+        Returns:
+            The value.
+        """
+        return self.value
+
+
 class DateTimeState(BaseState):
 class DateTimeState(BaseState):
     """A State with some datetime fields."""
     """A State with some datetime fields."""
 
 
@@ -329,6 +363,9 @@ def test_dict(test_state):
         "test_state.child_state",
         "test_state.child_state",
         "test_state.child_state.grandchild_state",
         "test_state.child_state.grandchild_state",
         "test_state.child_state2",
         "test_state.child_state2",
+        "test_state.child_state2.grandchild_state2",
+        "test_state.child_state3",
+        "test_state.child_state3.grandchild_state3",
     }
     }
     test_state_dict = test_state.dict()
     test_state_dict = test_state.dict()
     assert set(test_state_dict) == substates
     assert set(test_state_dict) == substates
@@ -380,10 +417,11 @@ def test_get_parent_state():
 
 
 def test_get_substates():
 def test_get_substates():
     """Test getting the substates."""
     """Test getting the substates."""
-    assert TestState.get_substates() == {ChildState, ChildState2}
+    assert TestState.get_substates() == {ChildState, ChildState2, ChildState3}
     assert ChildState.get_substates() == {GrandchildState}
     assert ChildState.get_substates() == {GrandchildState}
-    assert ChildState2.get_substates() == set()
+    assert ChildState2.get_substates() == {GrandchildState2}
     assert GrandchildState.get_substates() == set()
     assert GrandchildState.get_substates() == set()
+    assert GrandchildState2.get_substates() == set()
 
 
 
 
 def test_get_name():
 def test_get_name():
@@ -469,8 +507,8 @@ def test_set_parent_and_substates(test_state, child_state, grandchild_state):
         child_state: A child state.
         child_state: A child state.
         grandchild_state: A grandchild state.
         grandchild_state: A grandchild state.
     """
     """
-    assert len(test_state.substates) == 2
-    assert set(test_state.substates) == {"child_state", "child_state2"}
+    assert len(test_state.substates) == 3
+    assert set(test_state.substates) == {"child_state", "child_state2", "child_state3"}
 
 
     assert child_state.parent_state == test_state
     assert child_state.parent_state == test_state
     assert len(child_state.substates) == 1
     assert len(child_state.substates) == 1
@@ -655,7 +693,7 @@ def test_reset(test_state, child_state):
     assert child_state.dirty_vars == {"count", "value"}
     assert child_state.dirty_vars == {"count", "value"}
 
 
     # The dirty substates should be reset.
     # The dirty substates should be reset.
-    assert test_state.dirty_substates == {"child_state", "child_state2"}
+    assert test_state.dirty_substates == {"child_state", "child_state2", "child_state3"}
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
@@ -675,7 +713,10 @@ async def test_process_event_simple(test_state):
 
 
     # The delta should contain the changes, including computed vars.
     # The delta should contain the changes, including computed vars.
     # assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
     # assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
-    assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
+    assert update.delta == {
+        "test_state": {"num1": 69, "sum": 72.14, "upper": ""},
+        "test_state.child_state3.grandchild_state3": {"computed": ""},
+    }
     assert update.events == []
     assert update.events == []
 
 
 
 
@@ -700,6 +741,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     assert update.delta == {
     assert update.delta == {
         "test_state": {"sum": 3.14, "upper": ""},
         "test_state": {"sum": 3.14, "upper": ""},
         "test_state.child_state": {"value": "HI", "count": 24},
         "test_state.child_state": {"value": "HI", "count": 24},
+        "test_state.child_state3.grandchild_state3": {"computed": ""},
     }
     }
     test_state._clean()
     test_state._clean()
 
 
@@ -715,6 +757,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     assert update.delta == {
     assert update.delta == {
         "test_state": {"sum": 3.14, "upper": ""},
         "test_state": {"sum": 3.14, "upper": ""},
         "test_state.child_state.grandchild_state": {"value2": "new"},
         "test_state.child_state.grandchild_state": {"value2": "new"},
+        "test_state.child_state3.grandchild_state3": {"computed": ""},
     }
     }
 
 
 
 
@@ -1443,7 +1486,7 @@ def substate_token(state_manager, token):
     Returns:
     Returns:
         Token concatenated with the state_manager's state full_name.
         Token concatenated with the state_manager's state full_name.
     """
     """
-    return f"{token}_{state_manager.state.get_full_name()}"
+    return _substate_key(token, state_manager.state)
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
@@ -1545,7 +1588,7 @@ def substate_token_redis(state_manager_redis, token):
     Returns:
     Returns:
         Token concatenated with the state_manager's state full_name.
         Token concatenated with the state_manager's state full_name.
     """
     """
-    return f"{token}_{state_manager_redis.state.get_full_name()}"
+    return _substate_key(token, state_manager_redis.state)
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
@@ -1670,6 +1713,22 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
         # cannot directly modify state proxy outside of async context
         # cannot directly modify state proxy outside of async context
         sp.value2 = "16"
         sp.value2 = "16"
 
 
+    with pytest.raises(ImmutableStateError):
+        # Cannot get_state
+        await sp.get_state(ChildState)
+
+    with pytest.raises(ImmutableStateError):
+        # Cannot access get_substate
+        sp.get_substate([])
+
+    with pytest.raises(ImmutableStateError):
+        # Cannot access parent state
+        sp.parent_state.get_name()
+
+    with pytest.raises(ImmutableStateError):
+        # Cannot access substates
+        sp.substates[""]
+
     async with sp:
     async with sp:
         assert sp._self_actx is not None
         assert sp._self_actx is not None
         assert sp._self_mutable  # proxy is mutable inside context
         assert sp._self_mutable  # proxy is mutable inside context
@@ -1685,8 +1744,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     assert sp.value2 == "42"
     assert sp.value2 == "42"
 
 
     # Get the state from the state manager directly and check that the value is updated
     # Get the state from the state manager directly and check that the value is updated
-    gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}"
-    gotten_state = await mock_app.state_manager.get_state(gc_token)
+    gotten_state = await mock_app.state_manager.get_state(
+        _substate_key(grandchild_state.router.session.client_token, grandchild_state)
+    )
     if isinstance(mock_app.state_manager, StateManagerMemory):
     if isinstance(mock_app.state_manager, StateManagerMemory):
         # For in-process store, only one instance of the state exists
         # For in-process store, only one instance of the state exists
         assert gotten_state is parent_state
         assert gotten_state is parent_state
@@ -1710,6 +1770,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
             grandchild_state.get_full_name(): {
             grandchild_state.get_full_name(): {
                 "value2": "42",
                 "value2": "42",
             },
             },
+            GrandchildState3.get_full_name(): {
+                "computed": "",
+            },
         }
         }
     )
     )
     assert mcall.kwargs["to"] == grandchild_state.get_sid()
     assert mcall.kwargs["to"] == grandchild_state.get_sid()
@@ -1879,8 +1942,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         "private",
         "private",
     ]
     ]
 
 
-    substate_token = f"{token}_{BackgroundTaskState.get_name()}"
-    assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order
+    assert (
+        await mock_app.state_manager.get_state(
+            _substate_key(token, BackgroundTaskState)
+        )
+    ).order == exp_order
 
 
     assert mock_app.event_namespace is not None
     assert mock_app.event_namespace is not None
     emit_mock = mock_app.event_namespace.emit
     emit_mock = mock_app.event_namespace.emit
@@ -1957,8 +2023,11 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
         await task
         await task
     assert not mock_app.background_tasks
     assert not mock_app.background_tasks
 
 
-    substate_token = f"{token}_{BackgroundTaskState.get_name()}"
-    assert (await mock_app.state_manager.get_state(substate_token)).order == [
+    assert (
+        await mock_app.state_manager.get_state(
+            _substate_key(token, BackgroundTaskState)
+        )
+    ).order == [
         "reset",
         "reset",
     ]
     ]
 
 
@@ -2246,7 +2315,7 @@ def test_mutable_copy_vars(mutable_state, copy_func):
 
 
 
 
 def test_duplicate_substate_class(mocker):
 def test_duplicate_substate_class(mocker):
-    mocker.patch("reflex.state.os.environ", {})
+    mocker.patch("reflex.state.is_testing_env", lambda: False)
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
 
 
         class TestState(BaseState):
         class TestState(BaseState):
@@ -2435,7 +2504,9 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
         expected: Expected delta.
         expected: Expected delta.
         mocker: pytest mock object.
         mocker: pytest mock object.
     """
     """
-    mocker.patch("reflex.state.State.class_subclasses", {test_state})
+    mocker.patch(
+        "reflex.state.State.class_subclasses", {test_state, OnLoadInternalState}
+    )
     app = app_module_mock.app = App(
     app = app_module_mock.app = App(
         state=State, load_events={"index": [test_state.test_handler]}
         state=State, load_events={"index": [test_state.test_handler]}
     )
     )
@@ -2476,7 +2547,9 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
         token: A token.
         token: A token.
         mocker: pytest mock object.
         mocker: pytest mock object.
     """
     """
-    mocker.patch("reflex.state.State.class_subclasses", {OnLoadState})
+    mocker.patch(
+        "reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState}
+    )
     app = app_module_mock.app = App(
     app = app_module_mock.app = App(
         state=State,
         state=State,
         load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
         load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
@@ -2510,3 +2583,120 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
         OnLoadState.get_full_name(): {"num": 2}
         OnLoadState.get_full_name(): {"num": 2}
     }
     }
     assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
     assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
+
+
+@pytest.mark.asyncio
+async def test_get_state(mock_app: rx.App, token: str):
+    """Test that a get_state populates the top level state and delta calculation is correct.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+    mock_app.state_manager.state = mock_app.state = TestState
+
+    # Get instance of ChildState2.
+    test_state = await mock_app.state_manager.get_state(
+        _substate_key(token, ChildState2)
+    )
+    assert isinstance(test_state, TestState)
+    if isinstance(mock_app.state_manager, StateManagerMemory):
+        # All substates are available
+        assert tuple(sorted(test_state.substates)) == (
+            "child_state",
+            "child_state2",
+            "child_state3",
+        )
+    else:
+        # Sibling states are only populated if they have computed vars
+        assert tuple(sorted(test_state.substates)) == ("child_state2", "child_state3")
+
+    # Because ChildState3 has a computed var, it is always dirty, and always populated.
+    assert (
+        test_state.substates["child_state3"].substates["grandchild_state3"].computed
+        == ""
+    )
+
+    # Get the child_state2 directly.
+    child_state2_direct = test_state.get_substate(["child_state2"])
+    child_state2_get_state = await test_state.get_state(ChildState2)
+    # These should be the same object.
+    assert child_state2_direct is child_state2_get_state
+
+    # Get arbitrary GrandchildState.
+    grandchild_state = await child_state2_get_state.get_state(GrandchildState)
+    assert isinstance(grandchild_state, GrandchildState)
+
+    # Now the original root should have all substates populated.
+    assert tuple(sorted(test_state.substates)) == (
+        "child_state",
+        "child_state2",
+        "child_state3",
+    )
+
+    # ChildState should be retrievable
+    child_state_direct = test_state.get_substate(["child_state"])
+    child_state_get_state = await test_state.get_state(ChildState)
+    # These should be the same object.
+    assert child_state_direct is child_state_get_state
+
+    # GrandchildState instance should be the same as the one retrieved from the child_state2.
+    assert grandchild_state is child_state_direct.get_substate(["grandchild_state"])
+    grandchild_state.value2 = "set_value"
+
+    assert test_state.get_delta() == {
+        TestState.get_full_name(): {
+            "sum": 3.14,
+            "upper": "",
+        },
+        GrandchildState.get_full_name(): {
+            "value2": "set_value",
+        },
+        GrandchildState3.get_full_name(): {
+            "computed": "",
+        },
+    }
+
+    # Get a fresh instance
+    new_test_state = await mock_app.state_manager.get_state(
+        _substate_key(token, ChildState2)
+    )
+    assert isinstance(new_test_state, TestState)
+    if isinstance(mock_app.state_manager, StateManagerMemory):
+        # In memory, it's the same instance
+        assert new_test_state is test_state
+        test_state._clean()
+        # All substates are available
+        assert tuple(sorted(new_test_state.substates)) == (
+            "child_state",
+            "child_state2",
+            "child_state3",
+        )
+    else:
+        # With redis, we get a whole new instance
+        assert new_test_state is not test_state
+        # Sibling states are only populated if they have computed vars
+        assert tuple(sorted(new_test_state.substates)) == (
+            "child_state2",
+            "child_state3",
+        )
+
+    # Set a value on child_state2, should update cached var in grandchild_state2
+    child_state2 = new_test_state.get_substate(("child_state2",))
+    child_state2.value = "set_c2_value"
+
+    assert new_test_state.get_delta() == {
+        TestState.get_full_name(): {
+            "sum": 3.14,
+            "upper": "",
+        },
+        ChildState2.get_full_name(): {
+            "value": "set_c2_value",
+        },
+        GrandchildState2.get_full_name(): {
+            "cached": "set_c2_value",
+        },
+        GrandchildState3.get_full_name(): {
+            "computed": "",
+        },
+    }

+ 371 - 0
tests/test_state_tree.py

@@ -0,0 +1,371 @@
+"""Specialized test for a larger state tree."""
+import asyncio
+from typing import Generator
+
+import pytest
+
+import reflex as rx
+from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key
+
+
+class Root(BaseState):
+    """Root of the state tree."""
+
+    root: int
+
+
+class TreeA(Root):
+    """TreeA is a child of Root."""
+
+    a: int
+
+
+class SubA_A(TreeA):
+    """SubA_A is a child of TreeA."""
+
+    sub_a_a: int
+
+
+class SubA_A_A(SubA_A):
+    """SubA_A_A is a child of SubA_A."""
+
+    sub_a_a_a: int
+
+
+class SubA_A_A_A(SubA_A_A):
+    """SubA_A_A_A is a child of SubA_A_A."""
+
+    sub_a_a_a_a: int
+
+
+class SubA_A_A_B(SubA_A_A):
+    """SubA_A_A_B is a child of SubA_A_A."""
+
+    @rx.cached_var
+    def sub_a_a_a_cached(self) -> int:
+        """A cached var.
+
+        Returns:
+            The value of sub_a_a_a + 1
+        """
+        return self.sub_a_a_a + 1
+
+
+class SubA_A_A_C(SubA_A_A):
+    """SubA_A_A_C is a child of SubA_A_A."""
+
+    sub_a_a_a_c: int
+
+
+class SubA_A_B(SubA_A):
+    """SubA_A_B is a child of SubA_A."""
+
+    sub_a_a_b: int
+
+
+class SubA_B(TreeA):
+    """SubA_B is a child of TreeA."""
+
+    sub_a_b: int
+
+
+class TreeB(Root):
+    """TreeB is a child of Root."""
+
+    b: int
+
+
+class SubB_A(TreeB):
+    """SubB_A is a child of TreeB."""
+
+    sub_b_a: int
+
+
+class SubB_B(TreeB):
+    """SubB_B is a child of TreeB."""
+
+    sub_b_b: int
+
+
+class SubB_C(TreeB):
+    """SubB_C is a child of TreeB."""
+
+    sub_b_c: int
+
+
+class SubB_C_A(SubB_C):
+    """SubB_C_A is a child of SubB_C."""
+
+    sub_b_c_a: int
+
+
+class TreeC(Root):
+    """TreeC is a child of Root."""
+
+    c: int
+
+
+class SubC_A(TreeC):
+    """SubC_A is a child of TreeC."""
+
+    sub_c_a: int
+
+
+class TreeD(Root):
+    """TreeD is a child of Root."""
+
+    d: int
+
+    @rx.var
+    def d_var(self) -> int:
+        """A computed var.
+
+        Returns:
+            The value of d + 1
+        """
+        return self.d + 1
+
+
+class TreeE(Root):
+    """TreeE is a child of Root."""
+
+    e: int
+
+
+class SubE_A(TreeE):
+    """SubE_A is a child of TreeE."""
+
+    sub_e_a: int
+
+
+class SubE_A_A(SubE_A):
+    """SubE_A_A is a child of SubE_A."""
+
+    sub_e_a_a: int
+
+
+class SubE_A_A_A(SubE_A_A):
+    """SubE_A_A_A is a child of SubE_A_A."""
+
+    sub_e_a_a_a: int
+
+
+class SubE_A_A_A_A(SubE_A_A_A):
+    """SubE_A_A_A_A is a child of SubE_A_A_A."""
+
+    sub_e_a_a_a_a: int
+
+    @rx.var
+    def sub_e_a_a_a_a_var(self) -> int:
+        """A computed var.
+
+        Returns:
+            The value of sub_e_a_a_a_a + 1
+        """
+        return self.sub_e_a_a_a + 1
+
+
+class SubE_A_A_A_B(SubE_A_A_A):
+    """SubE_A_A_A_B is a child of SubE_A_A_A."""
+
+    sub_e_a_a_a_b: int
+
+
+class SubE_A_A_A_C(SubE_A_A_A):
+    """SubE_A_A_A_C is a child of SubE_A_A_A."""
+
+    sub_e_a_a_a_c: int
+
+
+class SubE_A_A_A_D(SubE_A_A_A):
+    """SubE_A_A_A_D is a child of SubE_A_A_A."""
+
+    sub_e_a_a_a_d: int
+
+    @rx.cached_var
+    def sub_e_a_a_a_d_var(self) -> int:
+        """A computed var.
+
+        Returns:
+            The value of sub_e_a_a_a_a + 1
+        """
+        return self.sub_e_a_a_a + 1
+
+
+ALWAYS_COMPUTED_VARS = {
+    TreeD.get_full_name(): {"d_var": 1},
+    SubE_A_A_A_A.get_full_name(): {"sub_e_a_a_a_a_var": 1},
+}
+
+ALWAYS_COMPUTED_DICT_KEYS = [
+    Root.get_full_name(),
+    TreeD.get_full_name(),
+    TreeE.get_full_name(),
+    SubE_A.get_full_name(),
+    SubE_A_A.get_full_name(),
+    SubE_A_A_A.get_full_name(),
+    SubE_A_A_A_A.get_full_name(),
+    SubE_A_A_A_D.get_full_name(),
+]
+
+
+@pytest.fixture(scope="function")
+def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]:
+    """Instance of state manager for redis only.
+
+    Args:
+        app_module_mock: The app module mock fixture.
+
+    Yields:
+        A state manager instance
+    """
+    app_module_mock.app = rx.App(state=Root)
+    state_manager = app_module_mock.app.state_manager
+
+    if not isinstance(state_manager, StateManagerRedis):
+        pytest.skip("Test requires redis")
+
+    yield state_manager
+
+    asyncio.get_event_loop().run_until_complete(state_manager.close())
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("substate_cls", "exp_root_substates", "exp_root_dict_keys"),
+    [
+        (
+            Root,
+            ["tree_a", "tree_b", "tree_c", "tree_d", "tree_e"],
+            [
+                TreeA.get_full_name(),
+                SubA_A.get_full_name(),
+                SubA_A_A.get_full_name(),
+                SubA_A_A_A.get_full_name(),
+                SubA_A_A_B.get_full_name(),
+                SubA_A_A_C.get_full_name(),
+                SubA_A_B.get_full_name(),
+                SubA_B.get_full_name(),
+                TreeB.get_full_name(),
+                SubB_A.get_full_name(),
+                SubB_B.get_full_name(),
+                SubB_C.get_full_name(),
+                SubB_C_A.get_full_name(),
+                TreeC.get_full_name(),
+                SubC_A.get_full_name(),
+                SubE_A_A_A_B.get_full_name(),
+                SubE_A_A_A_C.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeA,
+            ("tree_a", "tree_d", "tree_e"),
+            [
+                TreeA.get_full_name(),
+                SubA_A.get_full_name(),
+                SubA_A_A.get_full_name(),
+                SubA_A_A_A.get_full_name(),
+                SubA_A_A_B.get_full_name(),
+                SubA_A_A_C.get_full_name(),
+                SubA_A_B.get_full_name(),
+                SubA_B.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            SubA_A_A_A,
+            ["tree_a", "tree_d", "tree_e"],
+            [
+                TreeA.get_full_name(),
+                SubA_A.get_full_name(),
+                SubA_A_A.get_full_name(),
+                SubA_A_A_A.get_full_name(),
+                SubA_A_A_B.get_full_name(),  # Cached var dep
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeB,
+            ["tree_b", "tree_d", "tree_e"],
+            [
+                TreeB.get_full_name(),
+                SubB_A.get_full_name(),
+                SubB_B.get_full_name(),
+                SubB_C.get_full_name(),
+                SubB_C_A.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            SubB_B,
+            ["tree_b", "tree_d", "tree_e"],
+            [
+                TreeB.get_full_name(),
+                SubB_B.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            SubB_C_A,
+            ["tree_b", "tree_d", "tree_e"],
+            [
+                TreeB.get_full_name(),
+                SubB_C.get_full_name(),
+                SubB_C_A.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeC,
+            ["tree_c", "tree_d", "tree_e"],
+            [
+                TreeC.get_full_name(),
+                SubC_A.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeD,
+            ["tree_d", "tree_e"],
+            [
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeE,
+            ["tree_d", "tree_e"],
+            [
+                # Extra siblings of computed var included now.
+                SubE_A_A_A_B.get_full_name(),
+                SubE_A_A_A_C.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+    ],
+)
+async def test_get_state_tree(
+    state_manager_redis,
+    token,
+    substate_cls,
+    exp_root_substates,
+    exp_root_dict_keys,
+):
+    """Test getting state trees and assert on which branches are retrieved.
+
+    Args:
+        state_manager_redis: The state manager redis fixture.
+        token: The token fixture.
+        substate_cls: The substate class to retrieve.
+        exp_root_substates: The expected substates of the root state.
+        exp_root_dict_keys: The expected keys of the root state dict.
+    """
+    state = await state_manager_redis.get_state(_substate_key(token, substate_cls))
+    assert isinstance(state, Root)
+    assert sorted(state.substates) == sorted(exp_root_substates)
+
+    # Only computed vars should be returned
+    assert state.get_delta() == ALWAYS_COMPUTED_VARS
+
+    # All of TreeA, TreeD, and TreeE substates should be in the dict
+    assert sorted(state.dict()) == sorted(exp_root_dict_keys)

+ 8 - 2
tests/utils/test_format.py

@@ -13,8 +13,11 @@ from reflex.vars import BaseVar, Var
 from tests.test_state import (
 from tests.test_state import (
     ChildState,
     ChildState,
     ChildState2,
     ChildState2,
+    ChildState3,
     DateTimeState,
     DateTimeState,
     GrandchildState,
     GrandchildState,
+    GrandchildState2,
+    GrandchildState3,
     TestState,
     TestState,
 )
 )
 
 
@@ -649,7 +652,7 @@ formatted_router = {
     "input, output",
     "input, output",
     [
     [
         (
         (
-            TestState().dict(),  # type: ignore
+            TestState(_reflex_internal_init=True).dict(),  # type: ignore
             {
             {
                 TestState.get_full_name(): {
                 TestState.get_full_name(): {
                     "array": [1, 2, 3.14],
                     "array": [1, 2, 3.14],
@@ -674,11 +677,14 @@ formatted_router = {
                     "value": "",
                     "value": "",
                 },
                 },
                 ChildState2.get_full_name(): {"value": ""},
                 ChildState2.get_full_name(): {"value": ""},
+                ChildState3.get_full_name(): {"value": ""},
                 GrandchildState.get_full_name(): {"value2": ""},
                 GrandchildState.get_full_name(): {"value2": ""},
+                GrandchildState2.get_full_name(): {"cached": ""},
+                GrandchildState3.get_full_name(): {"computed": ""},
             },
             },
         ),
         ),
         (
         (
-            DateTimeState().dict(),
+            DateTimeState(_reflex_internal_init=True).dict(),  # type: ignore
             {
             {
                 DateTimeState.get_full_name(): {
                 DateTimeState.get_full_name(): {
                     "d": "1989-11-09",
                     "d": "1989-11-09",