瀏覽代碼

[REF-1988] API to Get instance of Arbitrary State class (#2678)

* WiP get_state

* Refactor get_state fast path

Rudimentary protection for state instance access from a background task
(StateProxy)

* retain dirty substate marking per `_mark_dirty` call to avoid test changes

* Find common ancestor by part instead of by character

Fix StateProxy for substates and parent_state attributes (have to handle in
__getattr__, not property)

Fix type annotation for `get_state`

* test_state: workflow test for `get_state` functionality

* Do not reset _always_dirty_substates when adding vars

Reset the substate tracking only when the class is instantiated.

* test_state_tree: test substate access in a larger state tree

Ensure that `get_state` returns the proper "branch" of the state tree depending
on what substate is requested.

* test_format: fixup broken tests from adding substates of TestState

* Fix flaky integration tests with more polling

* AppHarness: reset _always_dirty_substates on rx.State

* RuntimeError unless State is instantiated with _reflex_internal_init=True

Avoid user errors trying to directly instantiate State classes

* Helper functions for _substate_key and _split_substate_key

Unify the implementation of generating and decoding the token + state name
format used for redis state sharding.

* StateManagerRedis: use create_task in get_state and set_state

read and write substates concurrently (allow redis to shine)

* test_state_inheritance: use polling cuz life too short for flaky tests

kthnxbai :heart:

* Move _is_testing_env to reflex.utils.exec.is_testing_env

Reuse the code in app.py

* Break up `BaseState.get_state` and friends into separate methods

* Add test case for pre-fetching cached var dependency

* Move on_load_internal and update_vars_internal to substates

Avoid loading the entire state tree to process these common internal events. If
the state tree is very large, this allows page navigation to occur more quickly.

Pre-fetch substates that contain cached vars, as they may need to be recomputed
if certain vars change.

* Do not copy ROUTER_DATA into all substates.

This is a waste of time and memory, and can be handled via a special case in
__getattribute__

* Track whether State instance _was_touched

Avoid wasting time serializing states that have no modifications

* Do not persist states in `StateManagerRedis.get_state`

Wait until the state is actually modified, and then persist it as part of `set_state`.

Factor out common logic into helper methods for readability and to reduce
duplication.

To avoid having to recursively call `get_state`, which would require persisting
the instance and then getting it again, some of the initialization logic
regarding parent_state and substates is duplicated when creating a new
instance. This is for performance reasons.

* Remove stray print()

* context.js.jinja2: fix check for empty local storage / cookie vars

* Add comments for onLoadInternalEvent and initialEvents

* nit: typo

* split _get_was_touched into _update_was_touched

Improve clarity in cases where _get_was_touched was being called for its side
effects only.

* Remove extraneous information from incorrect State instantiation error

* Update missing redis exception message
Masen Furer 1 年之前
父節點
當前提交
deae662e2a

+ 2 - 2
integration/test_client_storage.py

@@ -518,8 +518,8 @@ async def test_client_side_state(
     set_sub("l6", "l6 value")
     l5 = driver.find_element(By.ID, "l5")
     l6 = driver.find_element(By.ID, "l6")
+    assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert l5.text == "l5 value"
-    assert l6.text == "l6 value"
 
     # Switch back to main window.
     driver.switch_to.window(main_tab)
@@ -527,8 +527,8 @@ async def test_client_side_state(
     # The values should have updated automatically.
     l5 = driver.find_element(By.ID, "l5")
     l6 = driver.find_element(By.ID, "l6")
+    assert AppHarness._poll_for(lambda: l6.text == "l6 value")
     assert l5.text == "l5 value"
-    assert l6.text == "l6 value"
 
     # clear the cookie jar and local storage, ensure state reset to default
     driver.delete_all_cookies()

+ 19 - 4
integration/test_state_inheritance.py

@@ -1,14 +1,29 @@
 """Test state inheritance."""
 
-import time
+from contextlib import suppress
 from typing import Generator
 
 import pytest
+from selenium.common.exceptions import NoAlertPresentException
+from selenium.webdriver.common.alert import Alert
 from selenium.webdriver.common.by import By
 
 from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
 
 
+def get_alert_or_none(driver: WebDriver) -> Alert | None:
+    """Switch to an alert if present.
+
+    Args:
+        driver: WebDriver instance.
+
+    Returns:
+        The alert if present, otherwise None.
+    """
+    with suppress(NoAlertPresentException):
+        return driver.switch_to.alert
+
+
 def raises_alert(driver: WebDriver, element: str) -> None:
     """Click an element and check that an alert is raised.
 
@@ -18,8 +33,8 @@ def raises_alert(driver: WebDriver, element: str) -> None:
     """
     btn = driver.find_element(By.ID, element)
     btn.click()
-    time.sleep(0.2)  # wait for the alert to appear
-    alert = driver.switch_to.alert
+    alert = AppHarness._poll_for(lambda: get_alert_or_none(driver))
+    assert isinstance(alert, Alert)
     assert alert.text == "clicked"
     alert.accept()
 
@@ -355,7 +370,7 @@ def test_state_inheritance(
     child3_other_mixin_btn = driver.find_element(By.ID, "child3-other-mixin-btn")
     child3_other_mixin_btn.click()
     child2_other_mixin_value = state_inheritance.poll_for_content(
-        child2_other_mixin, exp_not_equal="other_mixin"
+        child2_other_mixin, exp_not_equal="Child2.clicked.1"
     )
     child2_computed_mixin_value = state_inheritance.poll_for_content(
         child2_computed_other_mixin, exp_not_equal="other_mixin"

+ 24 - 4
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -25,11 +25,31 @@ export const clientStorage = {}
 
 {% if state_name %}
 export const state_name = "{{state_name}}"
-export const onLoadInternalEvent = () => [
-    Event('{{state_name}}.{{const.update_vars_internal}}', {vars: hydrateClientStorage(clientStorage)}),
-    Event('{{state_name}}.{{const.on_load_internal}}')
-]
 
+// These events are triggered on initial load and each page navigation.
+export const onLoadInternalEvent = () => {
+    const internal_events = [];
+
+    // Get tracked cookie and local storage vars to send to the backend.
+    const client_storage_vars = hydrateClientStorage(clientStorage);
+    // But only send the vars if any are actually set in the browser.
+    if (client_storage_vars && Object.keys(client_storage_vars).length !== 0) {
+        internal_events.push(
+            Event(
+                '{{state_name}}.{{const.update_vars_internal}}',
+                {vars: client_storage_vars},
+            ),
+        );
+    }
+
+    // `on_load_internal` triggers the correct on_load event(s) for the current page.
+    // If the page does not define any on_load event, this will just set `is_hydrated = true`.
+    internal_events.push(Event('{{state_name}}.{{const.on_load_internal}}'));
+
+    return internal_events;
+}
+
+// The following events are sent when the websocket connects or reconnects.
 export const initialEvents = () => [
     Event('{{state_name}}.{{const.hydrate}}'),
     ...onLoadInternalEvent()

+ 1 - 1
reflex/.templates/web/utils/state.js

@@ -587,7 +587,7 @@ export const useEventLoop = (
       if (storage_to_state_map[e.key]) {
         const vars = {}
         vars[storage_to_state_map[e.key]] = e.newValue
-        const event = Event(`${state_name}.update_vars_internal`, {vars: vars})
+        const event = Event(`${state_name}.update_vars_internal_state.update_vars_internal`, {vars: vars})
         addEvents([event], e);
       }
     };

+ 6 - 4
reflex/app.py

@@ -69,9 +69,11 @@ from reflex.state import (
     State,
     StateManager,
     StateUpdate,
+    _substate_key,
     code_uses_state_contexts,
 )
 from reflex.utils import console, exceptions, format, prerequisites, types
+from reflex.utils.exec import is_testing_env
 from reflex.utils.imports import ImportVar
 
 # Define custom types.
@@ -159,10 +161,9 @@ class App(Base):
             )
         super().__init__(*args, **kwargs)
         state_subclasses = BaseState.__subclasses__()
-        is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
 
         # Special case to allow test cases have multiple subclasses of rx.BaseState.
-        if not is_testing_env:
+        if not is_testing_env():
             # Only one Base State class is allowed.
             if len(state_subclasses) > 1:
                 raise ValueError(
@@ -176,7 +177,8 @@ class App(Base):
                     deprecation_version="0.3.5",
                     removal_version="0.5.0",
                 )
-            if len(State.class_subclasses) > 0:
+            # 2 substates are built-in and not considered when determining if app is stateless.
+            if len(State.class_subclasses) > 2:
                 self.state = State
         # Get the config
         config = get_config()
@@ -1002,7 +1004,7 @@ def upload(app: App):
             )
 
         # Get the state for the session.
-        substate_token = token + "_" + handler.rpartition(".")[0]
+        substate_token = _substate_key(token, handler.rpartition(".")[0])
         state = await app.state_manager.get_state(substate_token)
 
         # get the current session ID

+ 2 - 2
reflex/compiler/utils.py

@@ -138,12 +138,12 @@ def compile_state(state: Type[BaseState]) -> dict:
         A dictionary of the compiled state.
     """
     try:
-        initial_state = state().dict(initial=True)
+        initial_state = state(_reflex_internal_init=True).dict(initial=True)
     except Exception as e:
         console.warn(
             f"Failed to compile initial state with computed vars, excluding them: {e}"
         )
-        initial_state = state().dict(include_computed=False)
+        initial_state = state(_reflex_internal_init=True).dict(include_computed=False)
     return format.format_state(initial_state)
 
 

+ 2 - 2
reflex/constants/compiler.py

@@ -59,9 +59,9 @@ class CompileVars(SimpleNamespace):
     # The name of the function for converting a dict to an event.
     TO_EVENT = "Event"
     # The name of the internal on_load event.
-    ON_LOAD_INTERNAL = "on_load_internal"
+    ON_LOAD_INTERNAL = "on_load_internal_state.on_load_internal"
     # The name of the internal event to update generic state vars.
-    UPDATE_VARS_INTERNAL = "update_vars_internal"
+    UPDATE_VARS_INTERNAL = "update_vars_internal_state.update_vars_internal"
 
 
 class PageNames(SimpleNamespace):

+ 524 - 105
reflex/state.py

@@ -8,7 +8,6 @@ import copy
 import functools
 import inspect
 import json
-import os
 import traceback
 import urllib.parse
 import uuid
@@ -45,6 +44,7 @@ from reflex.event import (
 )
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
+from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import SerializedType, serialize, serializer
 from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 
@@ -151,9 +151,45 @@ RESERVED_BACKEND_VAR_NAMES = {
     "_substate_var_dependencies",
     "_always_dirty_computed_vars",
     "_always_dirty_substates",
+    "_was_touched",
 }
 
 
+def _substate_key(
+    token: str,
+    state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
+) -> str:
+    """Get the substate key.
+
+    Args:
+        token: The token of the state.
+        state_cls_or_name: The state class/instance or name or sequence of name parts.
+
+    Returns:
+        The substate key.
+    """
+    if isinstance(state_cls_or_name, BaseState) or (
+        isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
+    ):
+        state_cls_or_name = state_cls_or_name.get_full_name()
+    elif isinstance(state_cls_or_name, (list, tuple)):
+        state_cls_or_name = ".".join(state_cls_or_name)
+    return f"{token}_{state_cls_or_name}"
+
+
+def _split_substate_key(substate_key: str) -> tuple[str, str]:
+    """Split the substate key into token and state name.
+
+    Args:
+        substate_key: The substate key.
+
+    Returns:
+        Tuple of token and state name.
+    """
+    token, _, state_name = substate_key.partition("_")
+    return token, state_name
+
+
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
 
@@ -214,29 +250,46 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # The router data for the current page
     router: RouterData = RouterData()
 
+    # Whether the state has ever been touched since instantiation.
+    _was_touched: bool = False
+
     def __init__(
         self,
         *args,
         parent_state: BaseState | None = None,
         init_substates: bool = True,
+        _reflex_internal_init: bool = False,
         **kwargs,
     ):
         """Initialize the state.
 
+        DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead.
+
         Args:
             *args: The args to pass to the Pydantic init method.
             parent_state: The parent state.
             init_substates: Whether to initialize the substates in this instance.
+            _reflex_internal_init: A flag to indicate that the state is being initialized by the framework.
             **kwargs: The kwargs to pass to the Pydantic init method.
 
+        Raises:
+            RuntimeError: If the state is instantiated directly by end user.
         """
+        if not _reflex_internal_init and not is_testing_env():
+            raise RuntimeError(
+                "State classes should not be instantiated directly in a Reflex app. "
+                "See https://reflex.dev/docs/state for further information."
+            )
         kwargs["parent_state"] = parent_state
         super().__init__(*args, **kwargs)
 
         # Setup the substates (for memory state manager only).
         if init_substates:
             for substate in self.get_substates():
-                self.substates[substate.get_name()] = substate(parent_state=self)
+                self.substates[substate.get_name()] = substate(
+                    parent_state=self,
+                    _reflex_internal_init=True,
+                )
         # Convert the event handlers to functions.
         self._init_event_handlers()
 
@@ -287,7 +340,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Raises:
             ValueError: If a substate class shadows another.
         """
-        is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
         super().__init_subclass__(**kwargs)
         # Event handlers should not shadow builtin state methods.
         cls._check_overridden_methods()
@@ -295,6 +347,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Reset subclass tracking for this class.
         cls.class_subclasses = set()
 
+        # Reset dirty substate tracking for this class.
+        cls._always_dirty_substates = set()
+
         # Get the parent vars.
         parent_state = cls.get_parent_state()
         if parent_state is not None:
@@ -303,7 +358,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
             # Check if another substate class with the same name has already been defined.
             if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
-                if is_testing_env:
+                if is_testing_env():
                     # Clear existing subclass with same name when app is reloaded via
                     # utils.prerequisites.get_app(reload=True)
                     parent_state.class_subclasses = set(
@@ -325,6 +380,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             name: value
             for name, value in cls.__dict__.items()
             if types.is_backend_variable(name, cls)
+            and name not in RESERVED_BACKEND_VAR_NAMES
             and name not in cls.inherited_backend_vars
             and not isinstance(value, FunctionType)
         }
@@ -484,7 +540,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         )
 
         # Any substate containing a ComputedVar with cache=False always needs to be recomputed
-        cls._always_dirty_substates = set()
         if cls._always_dirty_computed_vars:
             # Tell parent classes that this substate has always dirty computed vars
             state_name = cls.get_name()
@@ -923,8 +978,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             **super().__getattribute__("inherited_vars"),
             **super().__getattribute__("inherited_backend_vars"),
         }
-        if name in inherited_vars:
-            return getattr(super().__getattribute__("parent_state"), name)
+
+        # For now, handle router_data updates as a special case.
+        if name in inherited_vars or name == constants.ROUTER_DATA:
+            parent_state = super().__getattribute__("parent_state")
+            if parent_state is not None:
+                return getattr(parent_state, name)
 
         backend_vars = super().__getattribute__("_backend_vars")
         if name in backend_vars:
@@ -980,9 +1039,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         if name == constants.ROUTER_DATA:
             self.dirty_vars.add(name)
             self._mark_dirty()
-            # propagate router_data updates down the state tree
-            for substate in self.substates.values():
-                setattr(substate, name, value)
 
     def reset(self):
         """Reset all the base vars to their default values."""
@@ -1036,6 +1092,170 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             raise ValueError(f"Invalid path: {path}")
         return self.substates[path[0]].get_substate(path[1:])
 
+    @classmethod
+    def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
+        """Find the name of the nearest common ancestor shared by this and the other state.
+
+        Args:
+            other: The other state.
+
+        Returns:
+            Full name of the nearest common ancestor.
+        """
+        common_ancestor_parts = []
+        for part1, part2 in zip(
+            cls.get_full_name().split("."),
+            other.get_full_name().split("."),
+        ):
+            if part1 != part2:
+                break
+            common_ancestor_parts.append(part1)
+        return ".".join(common_ancestor_parts)
+
+    @classmethod
+    def _determine_missing_parent_states(
+        cls, target_state_cls: Type[BaseState]
+    ) -> tuple[str, list[str]]:
+        """Determine the missing parent states between the target_state_cls and common ancestor of this state.
+
+        Args:
+            target_state_cls: The class of the state to find missing parent states for.
+
+        Returns:
+            The name of the common ancestor and the list of missing parent states.
+        """
+        common_ancestor_name = cls._get_common_ancestor(target_state_cls)
+        common_ancestor_parts = common_ancestor_name.split(".")
+        target_state_parts = tuple(target_state_cls.get_full_name().split("."))
+        relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
+
+        # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
+        fetch_parent_states = [common_ancestor_name]
+        for ix, relative_parent_state_name in enumerate(relative_target_state_parts):
+            fetch_parent_states.append(
+                ".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name])
+            )
+
+        return common_ancestor_name, fetch_parent_states[1:-1]
+
+    def _get_parent_states(self) -> list[tuple[str, BaseState]]:
+        """Get all parent state instances up to the root of the state tree.
+
+        Returns:
+            A list of tuples containing the name and the instance of each parent state.
+        """
+        parent_states_with_name = []
+        parent_state = self
+        while parent_state.parent_state is not None:
+            parent_state = parent_state.parent_state
+            parent_states_with_name.append((parent_state.get_full_name(), parent_state))
+        return parent_states_with_name
+
+    async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
+        """Populate substates in the tree between the target_state_cls and common ancestor of this state.
+
+        Args:
+            target_state_cls: The class of the state to populate parent states for.
+
+        Returns:
+            The parent state instance of target_state_cls.
+
+        Raises:
+            RuntimeError: If redis is not used in this backend process.
+        """
+        state_manager = get_state_manager()
+        if not isinstance(state_manager, StateManagerRedis):
+            raise RuntimeError(
+                f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
+                "(All states should already be available -- this is likely a bug).",
+            )
+
+        # Find the missing parent states up to the common ancestor.
+        (
+            common_ancestor_name,
+            missing_parent_states,
+        ) = self._determine_missing_parent_states(target_state_cls)
+
+        # Fetch all missing parent states and link them up to the common ancestor.
+        parent_states_by_name = dict(self._get_parent_states())
+        parent_state = parent_states_by_name[common_ancestor_name]
+        for parent_state_name in missing_parent_states:
+            parent_state = await state_manager.get_state(
+                token=_substate_key(
+                    self.router.session.client_token, parent_state_name
+                ),
+                top_level=False,
+                get_substates=False,
+                parent_state=parent_state,
+            )
+
+        # Return the direct parent of target_state_cls for subsequent linking.
+        return parent_state
+
+    def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
+        """Get a state instance from the cache.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The instance of state_cls associated with this state's client_token.
+        """
+        if self.parent_state is None:
+            root_state = self
+        else:
+            root_state = self._get_parent_states()[-1][1]
+        return root_state.get_substate(state_cls.get_full_name().split("."))
+
+    async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
+        """Get a state instance from redis.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The instance of state_cls associated with this state's client_token.
+
+        Raises:
+            RuntimeError: If redis is not used in this backend process.
+        """
+        # Fetch all missing parent states from redis.
+        parent_state_of_state_cls = await self._populate_parent_states(state_cls)
+
+        # Then get the target state and all its substates.
+        state_manager = get_state_manager()
+        if not isinstance(state_manager, StateManagerRedis):
+            raise RuntimeError(
+                f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
+                "(All states should already be available -- this is likely a bug).",
+            )
+        return await state_manager.get_state(
+            token=_substate_key(self.router.session.client_token, state_cls),
+            top_level=False,
+            get_substates=True,
+            parent_state=parent_state_of_state_cls,
+        )
+
+    async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
+        """Get an instance of the state associated with this token.
+
+        Allows for arbitrary access to sibling states from within an event handler.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The instance of state_cls associated with this state's client_token.
+        """
+        # Fast case - if this state instance is already cached, get_substate from root state.
+        try:
+            return self._get_state_from_cache(state_cls)
+        except ValueError:
+            pass
+
+        # Slow case - fetch missing parent states from redis.
+        return await self._get_state_from_redis(state_cls)
+
     def _get_event_handler(
         self, event: Event
     ) -> tuple[BaseState | StateProxy, EventHandler]:
@@ -1238,6 +1458,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             for cvar in self._computed_var_dependencies[dirty_var]
         )
 
+    @classmethod
+    def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
+        """Determine substates which could be affected by dirty vars in this state.
+
+        Returns:
+            Set of State classes that may need to be fetched to recalc computed vars.
+        """
+        # _always_dirty_substates need to be fetched to recalc computed vars.
+        fetch_substates = set(
+            cls.get_class_substate(tuple(substate_name.split(".")))
+            for substate_name in cls._always_dirty_substates
+        )
+        # Substates with cached vars also need to be fetched.
+        for dependent_substates in cls._substate_var_dependencies.values():
+            fetch_substates.update(
+                set(
+                    cls.get_class_substate(tuple(substate_name.split(".")))
+                    for substate_name in dependent_substates
+                )
+            )
+        return fetch_substates
+
     def get_delta(self) -> Delta:
         """Get the delta for the state.
 
@@ -1269,8 +1511,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Recursively find the substate deltas.
         substates = self.substates
         for substate in self.dirty_substates.union(self._always_dirty_substates):
-            if substate not in substates:
-                continue  # substate not loaded at this time, no delta
             delta.update(substates[substate].get_delta())
 
         # Format the delta.
@@ -1292,20 +1532,45 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # have to mark computed vars dirty to allow access to newly computed
         # values within the same ComputedVar function
         self._mark_dirty_computed_vars()
+        self._mark_dirty_substates()
 
-        # Propagate dirty var / computed var status into substates
+    def _mark_dirty_substates(self):
+        """Propagate dirty var / computed var status into substates."""
         substates = self.substates
         for var in self.dirty_vars:
             for substate_name in self._substate_var_dependencies[var]:
                 self.dirty_substates.add(substate_name)
-                if substate_name not in substates:
-                    continue
                 substate = substates[substate_name]
                 substate.dirty_vars.add(var)
                 substate._mark_dirty()
 
+    def _update_was_touched(self):
+        """Update the _was_touched flag based on dirty_vars."""
+        if self.dirty_vars and not self._was_touched:
+            for var in self.dirty_vars:
+                if var in self.base_vars or var in self._backend_vars:
+                    self._was_touched = True
+                    break
+
+    def _get_was_touched(self) -> bool:
+        """Check current dirty_vars and flag to determine if state instance was modified.
+
+        If any dirty vars belong to this state, mark _was_touched.
+
+        This flag determines whether this state instance should be persisted to redis.
+
+        Returns:
+            Whether this state instance was ever modified.
+        """
+        # Ensure the flag is up to date based on the current dirty_vars
+        self._update_was_touched()
+        return self._was_touched
+
     def _clean(self):
         """Reset the dirty vars."""
+        # Update touched status before cleaning dirty_vars.
+        self._update_was_touched()
+
         # Recursively clean the substates.
         for substate in self.dirty_substates:
             if substate not in self.substates:
@@ -1422,6 +1687,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         state["__dict__"] = state["__dict__"].copy()
         state["__dict__"]["parent_state"] = None
         state["__dict__"]["substates"] = {}
+        state["__dict__"].pop("_was_touched", None)
         return state
 
 
@@ -1431,6 +1697,35 @@ class State(BaseState):
     # The hydrated bool.
     is_hydrated: bool = False
 
+
+class UpdateVarsInternalState(State):
+    """Substate for handling internal state var updates."""
+
+    async def update_vars_internal(self, vars: dict[str, Any]) -> None:
+        """Apply updates to fully qualified state vars.
+
+        The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
+        and each value will be set on the appropriate substate instance.
+
+        This function is primarily used to apply cookie and local storage
+        updates from the frontend to the appropriate substate.
+
+        Args:
+            vars: The fully qualified vars and values to update.
+        """
+        for var, value in vars.items():
+            state_name, _, var_name = var.rpartition(".")
+            var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
+            var_state = await self.get_state(var_state_cls)
+            setattr(var_state, var_name, value)
+
+
+class OnLoadInternalState(State):
+    """Substate for handling on_load event enumeration.
+
+    This is a separate substate to avoid deserializing the entire state tree for every page navigation.
+    """
+
     def on_load_internal(self) -> list[Event | EventSpec] | None:
         """Queue on_load handlers for the current page.
 
@@ -1442,6 +1737,9 @@ class State(BaseState):
         load_events = app.get_load_events(self.router.page.path)
         if not load_events and self.is_hydrated:
             return  # Fast path for page-to-page navigation
+        if not load_events:
+            self.is_hydrated = True
+            return  # Fast path for initial hydrate with no on_load events defined.
         self.is_hydrated = False
         return [
             *fix_events(
@@ -1449,26 +1747,9 @@ class State(BaseState):
                 self.router.session.client_token,
                 router_data=self.router_data,
             ),
-            type(self).set_is_hydrated(True),  # type: ignore
+            State.set_is_hydrated(True),  # type: ignore
         ]
 
-    def update_vars_internal(self, vars: dict[str, Any]) -> None:
-        """Apply updates to fully qualified state vars.
-
-        The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
-        and each value will be set on the appropriate substate instance.
-
-        This function is primarily used to apply cookie and local storage
-        updates from the frontend to the appropriate substate.
-
-        Args:
-            vars: The fully qualified vars and values to update.
-        """
-        for var, value in vars.items():
-            state_name, _, var_name = var.rpartition(".")
-            var_state = self.get_substate(state_name.split("."))
-            setattr(var_state, var_name, value)
-
 
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.
@@ -1522,9 +1803,10 @@ class StateProxy(wrapt.ObjectProxy):
             This StateProxy instance in mutable mode.
         """
         self._self_actx = self._self_app.modify_state(
-            self.__wrapped__.router.session.client_token
-            + "_"
-            + ".".join(self._self_substate_path)
+            token=_substate_key(
+                self.__wrapped__.router.session.client_token,
+                self._self_substate_path,
+            )
         )
         mutable_state = await self._self_actx.__aenter__()
         super().__setattr__(
@@ -1574,7 +1856,15 @@ class StateProxy(wrapt.ObjectProxy):
 
         Returns:
             The value of the attribute.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
         """
+        if name in ["substates", "parent_state"] and not self._self_mutable:
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
         value = super().__getattr__(name)
         if not name.startswith("_self_") and isinstance(value, MutableProxy):
             # ensure mutations to these containers are blocked unless proxy is _mutable
@@ -1622,6 +1912,60 @@ class StateProxy(wrapt.ObjectProxy):
             "manager. Use `async with self` to modify state."
         )
 
+    def get_substate(self, path: Sequence[str]) -> BaseState:
+        """Look up a substate on the wrapped state; requires the proxy to be mutable.
+
+        Args:
+            path: The path to the substate.
+
+        Returns:
+            The substate.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if self._self_mutable:
+            return self.__wrapped__.get_substate(path)
+        raise ImmutableStateError(
+            "Background task StateProxy is immutable outside of a context "
+            "manager. Use `async with self` to modify state."
+        )
+
+    async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
+        """Fetch the instance of `state_cls` via the wrapped state; requires the proxy to be mutable.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The state.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if self._self_mutable:
+            return await self.__wrapped__.get_state(state_cls)
+        raise ImmutableStateError(
+            "Background task StateProxy is immutable outside of a context "
+            "manager. Use `async with self` to modify state."
+        )
+
+    def _as_state_update(self, *args, **kwargs) -> StateUpdate:
+        """Temporarily allow mutability to access parent_state, then restore the prior mode.
+
+        Args:
+            *args: The args to pass to the underlying state instance.
+            **kwargs: The kwargs to pass to the underlying state instance.
+
+        Returns:
+            The state update.
+        """
+        was_mutable, self._self_mutable = self._self_mutable, True
+        try:
+            return self.__wrapped__._as_state_update(*args, **kwargs)
+        finally:
+            self._self_mutable = was_mutable
+
 
 class StateUpdate(Base):
     """A state update sent to the frontend."""
@@ -1722,9 +2066,9 @@ class StateManagerMemory(StateManager):
             The state for the token.
         """
         # Memory state manager ignores the substate suffix and always returns the top-level state.
-        token = token.partition("_")[0]
+        token = _split_substate_key(token)[0]
         if token not in self.states:
-            self.states[token] = self.state()
+            self.states[token] = self.state(_reflex_internal_init=True)
         return self.states[token]
 
     async def set_state(self, token: str, state: BaseState):
@@ -1747,7 +2091,7 @@ class StateManagerMemory(StateManager):
             The state for the token.
         """
         # Memory state manager ignores the substate suffix and always returns the top-level state.
-        token = token.partition("_")[0]
+        token = _split_substate_key(token)[0]
         if token not in self._states_locks:
             async with self._state_manager_lock:
                 if token not in self._states_locks:
@@ -1787,6 +2131,81 @@ class StateManagerRedis(StateManager):
         b"evicted",
     }
 
+    def _get_root_state(self, state: BaseState) -> BaseState:
+        """Chase parent_state pointers up to the exact top-level state class (substates subclass it).
+
+        Args:
+            state: The state to start from.
+
+        Returns:
+            An instance of the top-level state (self.state), or the topmost ancestor if none matches.
+        """
+        while type(state) is not self.state and state.parent_state is not None:
+            state = state.parent_state
+        return state
+
+    async def _get_parent_state(self, token: str) -> BaseState | None:
+        """Fetch the parent of the state named in the token, if it has one.
+
+        Args:
+            token: The token to get the state for (_substate_key).
+
+        Returns:
+            The parent state for the state requested by the token or None if there is no such parent.
+        """
+        client_token, state_path = _split_substate_key(token)
+        parent_path, _, _ = state_path.rpartition(".")
+        if not parent_path:
+            # Empty head: the state path has no parent component.
+            return None
+        # Retrieve the parent state to populate event handlers onto this substate.
+        return await self.get_state(
+            token=_substate_key(client_token, parent_path),
+            top_level=False,
+            get_substates=False,
+        )
+
+    async def _populate_substates(
+        self,
+        token: str,
+        state: BaseState,
+        all_substates: bool = False,
+    ):
+        """Fetch substates from redis and attach them to the given state instance.
+
+        Nothing is returned; instead `state.substates` is filled in, and each substate
+        is requested with `parent_state=state` so that it links back to `state`.
+
+        Args:
+            token: The token to get the state for.
+            state: The state instance to populate substates for.
+            all_substates: Whether to fetch all substates or just required substates.
+        """
+        client_token = _split_substate_key(token)[0]
+
+        # All substates, or only _potentially_dirty_substates needed to recalc computed vars.
+        wanted = (
+            state.get_substates()
+            if all_substates
+            else state._potentially_dirty_substates()
+        )
+
+        # Start every redis fetch before awaiting any of them.
+        pending = {
+            substate_cls.get_name(): asyncio.create_task(
+                self.get_state(
+                    token=_substate_key(client_token, substate_cls),
+                    top_level=False,
+                    get_substates=all_substates,
+                    parent_state=state,
+                )
+            )
+            for substate_cls in wanted
+        }
+
+        for name, task in pending.items():
+            state.substates[name] = await task
+
     async def get_state(
         self,
         token: str,
@@ -1798,8 +2217,8 @@ class StateManagerRedis(StateManager):
 
         Args:
             token: The token to get the state for.
-            top_level: If true, return an instance of the top-level state.
-            get_substates: If true, also retrieve substates
+            top_level: If true, return an instance of the top-level state (self.state).
+            get_substates: If true, also retrieve substates.
             parent_state: If provided, use this parent_state instead of getting it from redis.
 
         Returns:
@@ -1809,7 +2228,7 @@ class StateManagerRedis(StateManager):
             RuntimeError: when the state_cls is not specified in the token
         """
         # Split the actual token from the fully qualified substate name.
-        client_token, _, state_path = token.partition("_")
+        _, state_path = _split_substate_key(token)
         if state_path:
             # Get the State class associated with the given path.
             state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
@@ -1825,66 +2244,49 @@ class StateManagerRedis(StateManager):
             # Deserialize the substate.
             state = cloudpickle.loads(redis_state)
 
-            # Populate parent and substates if requested.
+            # Populate parent state if missing and requested.
             if parent_state is None:
-                # Retrieve the parent state from redis.
-                parent_state_name = state_path.rpartition(".")[0]
-                if parent_state_name:
-                    parent_state_key = token.rpartition(".")[0]
-                    parent_state = await self.get_state(
-                        parent_state_key, top_level=False, get_substates=False
-                    )
+                parent_state = await self._get_parent_state(token)
             # Set up Bidirectional linkage between this state and its parent.
             if parent_state is not None:
                 parent_state.substates[state.get_name()] = state
                 state.parent_state = parent_state
-            if get_substates:
-                # Retrieve all substates from redis.
-                for substate_cls in state_cls.get_substates():
-                    substate_name = substate_cls.get_name()
-                    substate_key = token + "." + substate_name
-                    state.substates[substate_name] = await self.get_state(
-                        substate_key, top_level=False, parent_state=state
-                    )
+            # Populate substates if requested.
+            await self._populate_substates(token, state, all_substates=get_substates)
+
             # To retain compatibility with previous implementation, by default, we return
             # the top-level state by chasing `parent_state` pointers up the tree.
             if top_level:
-                while type(state) != self.state and state.parent_state is not None:
-                    state = state.parent_state
+                return self._get_root_state(state)
             return state
 
-        # Key didn't exist so we have to create a new entry for this token.
+        # TODO: dedupe the following logic with the above block
+        # Key didn't exist so we have to create a new instance for this token.
         if parent_state is None:
-            parent_state_name = state_path.rpartition(".")[0]
-            if parent_state_name:
-                # Retrieve the parent state to populate event handlers onto this substate.
-                parent_state_key = client_token + "_" + parent_state_name
-                parent_state = await self.get_state(
-                    parent_state_key, top_level=False, get_substates=False
-                )
-        # Persist the new state class to redis.
-        await self.set_state(
-            token,
-            state_cls(
-                parent_state=parent_state,
-                init_substates=False,
-            ),
-        )
-        # After creating the state key, recursively call `get_state` to populate substates.
-        return await self.get_state(
-            token,
-            top_level=top_level,
-            get_substates=get_substates,
+            parent_state = await self._get_parent_state(token)
+        # Instantiate the new state class (but don't persist it yet).
+        state = state_cls(
             parent_state=parent_state,
+            init_substates=False,
+            _reflex_internal_init=True,
         )
+        # Set up Bidirectional linkage between this state and its parent.
+        if parent_state is not None:
+            parent_state.substates[state.get_name()] = state
+            state.parent_state = parent_state
+        # Populate substates for the newly created state.
+        await self._populate_substates(token, state, all_substates=get_substates)
+        # To retain compatibility with previous implementation, by default, we return
+        # the top-level state by chasing `parent_state` pointers up the tree.
+        if top_level:
+            return self._get_root_state(state)
+        return state
 
     async def set_state(
         self,
         token: str,
         state: BaseState,
         lock_id: bytes | None = None,
-        set_substates: bool = True,
-        set_parent_state: bool = True,
     ):
         """Set the state for a token.
 
@@ -1892,11 +2294,10 @@ class StateManagerRedis(StateManager):
             token: The token to set the state for.
             state: The state to set.
             lock_id: If provided, the lock_key must be set to this value to set the state.
-            set_substates: If True, write substates to redis
-            set_parent_state: If True, write parent state to redis
 
         Raises:
             LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
+            RuntimeError: If the state instance doesn't match the state name in the token.
         """
         # Check that we're holding the lock.
         if (
@@ -1908,28 +2309,36 @@ class StateManagerRedis(StateManager):
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                 "or use `@rx.background` decorator for long-running tasks."
             )
-        # Find the substate associated with the token.
-        state_path = token.partition("_")[2]
-        if state_path and state.get_full_name() != state_path:
-            state = state.get_substate(tuple(state_path.split(".")))
-        # Persist the parent state separately, if requested.
-        if state.parent_state is not None and set_parent_state:
-            parent_state_key = token.rpartition(".")[0]
-            await self.set_state(
-                parent_state_key,
-                state.parent_state,
-                lock_id=lock_id,
-                set_substates=False,
+        client_token, substate_name = _split_substate_key(token)
+        # If the substate name on the token doesn't match the instance name, it cannot have a parent.
+        if state.parent_state is not None and state.get_full_name() != substate_name:
+            raise RuntimeError(
+                f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
             )
-        # Persist the substates separately, if requested.
-        if set_substates:
-            for substate_name, substate in state.substates.items():
-                substate_key = token + "." + substate_name
-                await self.set_state(
-                    substate_key, substate, lock_id=lock_id, set_parent_state=False
+
+        # Recursively set_state on all known substates.
+        tasks = []
+        for substate in state.substates.values():
+            tasks.append(
+                asyncio.create_task(
+                    self.set_state(
+                        token=_substate_key(client_token, substate),
+                        state=substate,
+                        lock_id=lock_id,
+                    )
                 )
+            )
         # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
-        await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
+        if state._get_was_touched():
+            await self.redis.set(
+                _substate_key(client_token, state),
+                cloudpickle.dumps(state),
+                ex=self.token_expiration,
+            )
+
+        # Wait for substates to be persisted.
+        for t in tasks:
+            await t
 
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
@@ -1957,7 +2366,7 @@ class StateManagerRedis(StateManager):
             The redis lock key for the token.
         """
         # All substates share the same lock domain, so ignore any substate path suffix.
-        client_token = token.partition("_")[0]
+        client_token = _split_substate_key(token)[0]
         return f"{client_token}_lock".encode()
 
     async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
@@ -2052,6 +2461,16 @@ class StateManagerRedis(StateManager):
         await self.redis.close(close_connection_pool=True)
 
 
+def get_state_manager() -> StateManager:
+    """Return the state manager of the app that is currently running.
+
+    Returns:
+        The state manager.
+    """
+    loaded_app = prerequisites.get_app()
+    return getattr(loaded_app, constants.CompileVars.APP).state_manager
+
+
 class ClientStorageBase:
     """Base class for client-side storage."""
 

+ 6 - 0
reflex/testing.py

@@ -70,6 +70,10 @@ else:
     FRONTEND_POPEN_ARGS["start_new_session"] = True
 
 
+# Save a copy of internal substates to reset after each test.
+INTERNAL_STATES = State.class_subclasses.copy()
+
+
 # borrowed from py3.11
 class chdir(contextlib.AbstractContextManager):
     """Non thread-safe context manager to change the current working directory."""
@@ -220,6 +224,8 @@ class AppHarness:
             reflex.config.get_config(reload=True)
             # reset rx.State subclasses
             State.class_subclasses.clear()
+            State.class_subclasses.update(INTERNAL_STATES)
+            State._always_dirty_substates = set()
             State.get_class_substate.cache_clear()
             # Ensure the AppHarness test does not skip State assignment due to running via pytest
             os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)

+ 9 - 0
reflex/utils/exec.py

@@ -285,3 +285,12 @@ def output_system_info():
     console.debug(f"Using package executer at: {prerequisites.get_package_manager()}")  # type: ignore
     if system != "Windows":
         console.debug(f"Unzip path: {path_ops.which('unzip')}")
+
+
+def is_testing_env() -> bool:
+    """Whether the app is running in a testing environment.
+
+    Returns:
+        True if the variable named by `constants.PYTEST_CURRENT_TEST` is in the environment (running under pytest).
+    """
+    return constants.PYTEST_CURRENT_TEST in os.environ

+ 10 - 0
reflex/vars.py

@@ -1875,6 +1875,10 @@ class ComputedVar(Var, property):
 
         Returns:
             A set of variable names accessed by the given obj.
+
+        Raises:
+            ValueError: if the function references the get_state, parent_state, or substates attributes
+                (cannot track deps in a related state, only implicitly via parent state).
         """
         d = set()
         if obj is None:
@@ -1898,6 +1902,8 @@ class ComputedVar(Var, property):
         if self_name is None:
             # cannot reference attributes on self if method takes no args
             return set()
+
+        invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
         self_is_top_of_stack = False
         for instruction in dis.get_instructions(obj):
             if (
@@ -1916,6 +1922,10 @@ class ComputedVar(Var, property):
                     ref_obj = getattr(objclass, instruction.argval)
                 except Exception:
                     ref_obj = None
+                if instruction.argval in invalid_names:
+                    raise ValueError(
+                        f"Cached var {self._var_full_name} cannot access arbitrary state via `{instruction.argval}`."
+                    )
                 if callable(ref_obj):
                     # recurse into callable attributes
                     d.update(

+ 17 - 10
tests/test_app.py

@@ -29,7 +29,15 @@ from reflex.components.radix.themes.typography.text import Text
 from reflex.event import Event
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
-from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
+from reflex.state import (
+    BaseState,
+    OnLoadInternalState,
+    RouterData,
+    State,
+    StateManagerRedis,
+    StateUpdate,
+    _substate_key,
+)
 from reflex.style import Style
 from reflex.utils import format
 from reflex.vars import ComputedVar
@@ -362,7 +370,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str):
     assert app.state == test_state
 
     # Get a state for a given token.
-    state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}")
+    state = await app.state_manager.get_state(_substate_key(token, test_state))
     assert isinstance(state, test_state)
     assert state.var == 0  # type: ignore
 
@@ -766,8 +774,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
     # The App state must be the "root" of the state tree
     app = App(state=State)
     app.event_namespace.emit = AsyncMock()  # type: ignore
-    substate_token = f"{token}_{state.get_full_name()}"
-    current_state = await app.state_manager.get_state(substate_token)
+    current_state = await app.state_manager.get_state(_substate_key(token, state))
     data = b"This is binary data"
 
     # Create a binary IO object and write data to it
@@ -796,7 +803,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
             == StateUpdate(delta=delta, events=[], final=True).json() + "\n"
         )
 
-    current_state = await app.state_manager.get_state(substate_token)
+    current_state = await app.state_manager.get_state(_substate_key(token, state))
     state_dict = current_state.dict()[state.get_full_name()]
     assert state_dict["img_list"] == [
         "image1.jpg",
@@ -913,7 +920,7 @@ class DynamicState(BaseState):
         # self.side_effect_counter = self.side_effect_counter + 1
         return self.dynamic
 
-    on_load_internal = State.on_load_internal.fn
+    on_load_internal = OnLoadInternalState.on_load_internal.fn
 
 
 @pytest.mark.asyncio
@@ -950,7 +957,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     }
     assert constants.ROUTER in app.state()._computed_var_dependencies
 
-    substate_token = f"{token}_{DynamicState.get_full_name()}"
+    substate_token = _substate_key(token, DynamicState)
     sid = "mock_sid"
     client_ip = "127.0.0.1"
     state = await app.state_manager.get_state(substate_token)
@@ -978,7 +985,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     prev_exp_val = ""
     for exp_index, exp_val in enumerate(exp_vals):
         on_load_internal = _event(
-            name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}",
+            name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}",
             val=exp_val,
         )
         exp_router_data = {
@@ -1013,8 +1020,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                     name="on_load",
                     val=exp_val,
                 ),
-                _dynamic_state_event(
-                    name="set_is_hydrated",
+                _event(
+                    name="state.set_is_hydrated",
                     payload={"value": True},
                     val=exp_val,
                     router_data={},

+ 207 - 17
tests/test_state.py

@@ -23,6 +23,7 @@ from reflex.state import (
     ImmutableStateError,
     LockExpiredError,
     MutableProxy,
+    OnLoadInternalState,
     RouterData,
     State,
     StateManager,
@@ -30,6 +31,7 @@ from reflex.state import (
     StateManagerRedis,
     StateProxy,
     StateUpdate,
+    _substate_key,
 )
 from reflex.utils import prerequisites, types
 from reflex.utils.format import json_dumps
@@ -139,6 +141,12 @@ class ChildState2(TestState):
     value: str
 
 
+class ChildState3(TestState):
+    """A child state fixture with a single string var."""
+
+    value: str
+
+
 class GrandchildState(ChildState):
     """A grandchild state fixture."""
 
@@ -149,6 +157,32 @@ class GrandchildState(ChildState):
         pass
 
 
+class GrandchildState2(ChildState2):
+    """A grandchild state fixture with a cached var."""
+
+    @rx.cached_var
+    def cached(self) -> str:
+        """A cached var that mirrors the inherited `value`.
+
+        Returns:
+            The value.
+        """
+        return self.value
+
+
+class GrandchildState3(ChildState3):
+    """A grandchild state fixture with a computed var."""
+
+    @rx.var
+    def computed(self) -> str:
+        """A computed var that mirrors the inherited `value`.
+
+        Returns:
+            The value.
+        """
+        return self.value
+
+
 class DateTimeState(BaseState):
     """A State with some datetime fields."""
 
@@ -329,6 +363,9 @@ def test_dict(test_state):
         "test_state.child_state",
         "test_state.child_state.grandchild_state",
         "test_state.child_state2",
+        "test_state.child_state2.grandchild_state2",
+        "test_state.child_state3",
+        "test_state.child_state3.grandchild_state3",
     }
     test_state_dict = test_state.dict()
     assert set(test_state_dict) == substates
@@ -380,10 +417,11 @@ def test_get_parent_state():
 
 def test_get_substates():
     """Test getting the substates."""
-    assert TestState.get_substates() == {ChildState, ChildState2}
+    assert TestState.get_substates() == {ChildState, ChildState2, ChildState3}
     assert ChildState.get_substates() == {GrandchildState}
-    assert ChildState2.get_substates() == set()
+    assert ChildState2.get_substates() == {GrandchildState2}
     assert GrandchildState.get_substates() == set()
+    assert GrandchildState2.get_substates() == set()
 
 
 def test_get_name():
@@ -469,8 +507,8 @@ def test_set_parent_and_substates(test_state, child_state, grandchild_state):
         child_state: A child state.
         grandchild_state: A grandchild state.
     """
-    assert len(test_state.substates) == 2
-    assert set(test_state.substates) == {"child_state", "child_state2"}
+    assert len(test_state.substates) == 3
+    assert set(test_state.substates) == {"child_state", "child_state2", "child_state3"}
 
     assert child_state.parent_state == test_state
     assert len(child_state.substates) == 1
@@ -655,7 +693,7 @@ def test_reset(test_state, child_state):
     assert child_state.dirty_vars == {"count", "value"}
 
     # The dirty substates should be reset.
-    assert test_state.dirty_substates == {"child_state", "child_state2"}
+    assert test_state.dirty_substates == {"child_state", "child_state2", "child_state3"}
 
 
 @pytest.mark.asyncio
@@ -675,7 +713,10 @@ async def test_process_event_simple(test_state):
 
     # The delta should contain the changes, including computed vars.
     # assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
-    assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
+    assert update.delta == {
+        "test_state": {"num1": 69, "sum": 72.14, "upper": ""},
+        "test_state.child_state3.grandchild_state3": {"computed": ""},
+    }
     assert update.events == []
 
 
@@ -700,6 +741,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     assert update.delta == {
         "test_state": {"sum": 3.14, "upper": ""},
         "test_state.child_state": {"value": "HI", "count": 24},
+        "test_state.child_state3.grandchild_state3": {"computed": ""},
     }
     test_state._clean()
 
@@ -715,6 +757,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     assert update.delta == {
         "test_state": {"sum": 3.14, "upper": ""},
         "test_state.child_state.grandchild_state": {"value2": "new"},
+        "test_state.child_state3.grandchild_state3": {"computed": ""},
     }
 
 
@@ -1443,7 +1486,7 @@ def substate_token(state_manager, token):
     Returns:
         Token concatenated with the state_manager's state full_name.
     """
-    return f"{token}_{state_manager.state.get_full_name()}"
+    return _substate_key(token, state_manager.state)
 
 
 @pytest.mark.asyncio
@@ -1545,7 +1588,7 @@ def substate_token_redis(state_manager_redis, token):
     Returns:
         Token concatenated with the state_manager's state full_name.
     """
-    return f"{token}_{state_manager_redis.state.get_full_name()}"
+    return _substate_key(token, state_manager_redis.state)
 
 
 @pytest.mark.asyncio
@@ -1670,6 +1713,22 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
         # cannot directly modify state proxy outside of async context
         sp.value2 = "16"
 
+    with pytest.raises(ImmutableStateError):
+        # Cannot get_state
+        await sp.get_state(ChildState)
+
+    with pytest.raises(ImmutableStateError):
+        # Cannot access get_substate
+        sp.get_substate([])
+
+    with pytest.raises(ImmutableStateError):
+        # Cannot access parent state
+        sp.parent_state.get_name()
+
+    with pytest.raises(ImmutableStateError):
+        # Cannot access substates
+        sp.substates[""]
+
     async with sp:
         assert sp._self_actx is not None
         assert sp._self_mutable  # proxy is mutable inside context
@@ -1685,8 +1744,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     assert sp.value2 == "42"
 
     # Get the state from the state manager directly and check that the value is updated
-    gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}"
-    gotten_state = await mock_app.state_manager.get_state(gc_token)
+    gotten_state = await mock_app.state_manager.get_state(
+        _substate_key(grandchild_state.router.session.client_token, grandchild_state)
+    )
     if isinstance(mock_app.state_manager, StateManagerMemory):
         # For in-process store, only one instance of the state exists
         assert gotten_state is parent_state
@@ -1710,6 +1770,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
             grandchild_state.get_full_name(): {
                 "value2": "42",
             },
+            GrandchildState3.get_full_name(): {
+                "computed": "",
+            },
         }
     )
     assert mcall.kwargs["to"] == grandchild_state.get_sid()
@@ -1879,8 +1942,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         "private",
     ]
 
-    substate_token = f"{token}_{BackgroundTaskState.get_name()}"
-    assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order
+    assert (
+        await mock_app.state_manager.get_state(
+            _substate_key(token, BackgroundTaskState)
+        )
+    ).order == exp_order
 
     assert mock_app.event_namespace is not None
     emit_mock = mock_app.event_namespace.emit
@@ -1957,8 +2023,11 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
         await task
     assert not mock_app.background_tasks
 
-    substate_token = f"{token}_{BackgroundTaskState.get_name()}"
-    assert (await mock_app.state_manager.get_state(substate_token)).order == [
+    assert (
+        await mock_app.state_manager.get_state(
+            _substate_key(token, BackgroundTaskState)
+        )
+    ).order == [
         "reset",
     ]
 
@@ -2246,7 +2315,7 @@ def test_mutable_copy_vars(mutable_state, copy_func):
 
 
 def test_duplicate_substate_class(mocker):
-    mocker.patch("reflex.state.os.environ", {})
+    mocker.patch("reflex.state.is_testing_env", lambda: False)
     with pytest.raises(ValueError):
 
         class TestState(BaseState):
@@ -2435,7 +2504,9 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
         expected: Expected delta.
         mocker: pytest mock object.
     """
-    mocker.patch("reflex.state.State.class_subclasses", {test_state})
+    mocker.patch(
+        "reflex.state.State.class_subclasses", {test_state, OnLoadInternalState}
+    )
     app = app_module_mock.app = App(
         state=State, load_events={"index": [test_state.test_handler]}
     )
@@ -2476,7 +2547,9 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
         token: A token.
         mocker: pytest mock object.
     """
-    mocker.patch("reflex.state.State.class_subclasses", {OnLoadState})
+    mocker.patch(
+        "reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState}
+    )
     app = app_module_mock.app = App(
         state=State,
         load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
@@ -2510,3 +2583,120 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
         OnLoadState.get_full_name(): {"num": 2}
     }
     assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
+
+
+@pytest.mark.asyncio
+async def test_get_state(mock_app: rx.App, token: str):
+    """Test that get_state populates the top level state and that deltas are correct.
+
+    The root state is fetched through the state manager using a substate key
+    for ChildState2, then further substates are requested via
+    `BaseState.get_state` and the resulting tree and deltas are checked.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+    mock_app.state_manager.state = mock_app.state = TestState
+
+    # Get instance of ChildState2.
+    test_state = await mock_app.state_manager.get_state(
+        _substate_key(token, ChildState2)
+    )
+    # The state manager hands back the root of the tree, not ChildState2 itself.
+    assert isinstance(test_state, TestState)
+    if isinstance(mock_app.state_manager, StateManagerMemory):
+        # All substates are available
+        assert tuple(sorted(test_state.substates)) == (
+            "child_state",
+            "child_state2",
+            "child_state3",
+        )
+    else:
+        # Sibling states are only populated if they have computed vars
+        assert tuple(sorted(test_state.substates)) == ("child_state2", "child_state3")
+
+    # The child_state3 branch is populated in both cases: its grandchild_state3
+    # substate carries the `computed` var (presumably a non-cached computed var,
+    # which is always dirty -- confirm against the GrandchildState3 definition).
+    assert (
+        test_state.substates["child_state3"].substates["grandchild_state3"].computed
+        == ""
+    )
+
+    # Get the child_state2 directly.
+    child_state2_direct = test_state.get_substate(["child_state2"])
+    child_state2_get_state = await test_state.get_state(ChildState2)
+    # These should be the same object.
+    assert child_state2_direct is child_state2_get_state
+
+    # Get arbitrary GrandchildState.
+    grandchild_state = await child_state2_get_state.get_state(GrandchildState)
+    assert isinstance(grandchild_state, GrandchildState)
+
+    # Now the original root should have all substates populated.
+    assert tuple(sorted(test_state.substates)) == (
+        "child_state",
+        "child_state2",
+        "child_state3",
+    )
+
+    # ChildState should be retrievable
+    child_state_direct = test_state.get_substate(["child_state"])
+    child_state_get_state = await test_state.get_state(ChildState)
+    # These should be the same object.
+    assert child_state_direct is child_state_get_state
+
+    # The GrandchildState fetched earlier through child_state2 must be the very
+    # object stored under child_state in the tree.
+    assert grandchild_state is child_state_direct.get_substate(["grandchild_state"])
+    grandchild_state.value2 = "set_value"
+
+    assert test_state.get_delta() == {
+        TestState.get_full_name(): {
+            "sum": 3.14,
+            "upper": "",
+        },
+        GrandchildState.get_full_name(): {
+            "value2": "set_value",
+        },
+        GrandchildState3.get_full_name(): {
+            "computed": "",
+        },
+    }
+
+    # Get a fresh instance
+    new_test_state = await mock_app.state_manager.get_state(
+        _substate_key(token, ChildState2)
+    )
+    assert isinstance(new_test_state, TestState)
+    if isinstance(mock_app.state_manager, StateManagerMemory):
+        # In memory, it's the same instance
+        assert new_test_state is test_state
+        # Presumably clears the dirty markers left by the `value2` assignment
+        # above so the next delta only reflects new changes -- confirm.
+        test_state._clean()
+        # All substates are available
+        assert tuple(sorted(new_test_state.substates)) == (
+            "child_state",
+            "child_state2",
+            "child_state3",
+        )
+    else:
+        # With redis, we get a whole new instance
+        assert new_test_state is not test_state
+        # Sibling states are only populated if they have computed vars
+        assert tuple(sorted(new_test_state.substates)) == (
+            "child_state2",
+            "child_state3",
+        )
+
+    # Set a value on child_state2, should update cached var in grandchild_state2
+    child_state2 = new_test_state.get_substate(("child_state2",))
+    child_state2.value = "set_c2_value"
+
+    assert new_test_state.get_delta() == {
+        TestState.get_full_name(): {
+            "sum": 3.14,
+            "upper": "",
+        },
+        ChildState2.get_full_name(): {
+            "value": "set_c2_value",
+        },
+        GrandchildState2.get_full_name(): {
+            "cached": "set_c2_value",
+        },
+        GrandchildState3.get_full_name(): {
+            "computed": "",
+        },
+    }

+ 371 - 0
tests/test_state_tree.py

@@ -0,0 +1,371 @@
+"""Specialized test for a larger state tree."""
+import asyncio
+from typing import Generator
+
+import pytest
+
+import reflex as rx
+from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key
+
+
+class Root(BaseState):
+    """Root of the state tree."""
+
+    # Plain integer field; Root declares no computed vars of its own.
+    root: int
+
+
+class TreeA(Root):
+    """TreeA is a child of Root."""
+
+    # Plain integer field; the only cached var in this branch is declared
+    # further down, on SubA_A_A_B.
+    a: int
+
+
+class SubA_A(TreeA):
+    """SubA_A is a child of TreeA."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_a_a: int
+
+
+class SubA_A_A(SubA_A):
+    """SubA_A_A is a child of SubA_A."""
+
+    # Read by the cached var `SubA_A_A_B.sub_a_a_a_cached`.
+    sub_a_a_a: int
+
+
+class SubA_A_A_A(SubA_A_A):
+    """SubA_A_A_A is a child of SubA_A_A."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_a_a_a_a: int
+
+
+class SubA_A_A_B(SubA_A_A):
+    """SubA_A_A_B is a child of SubA_A_A.
+
+    It declares no fields of its own, only a cached var that reads a field
+    declared on its parent.
+    """
+
+    @rx.cached_var
+    def sub_a_a_a_cached(self) -> int:
+        """A cached var.
+
+        Returns:
+            The value of sub_a_a_a + 1
+        """
+        # `sub_a_a_a` is declared on the parent class SubA_A_A.
+        return self.sub_a_a_a + 1
+
+
+class SubA_A_A_C(SubA_A_A):
+    """SubA_A_A_C is a child of SubA_A_A."""
+
+    # Plain integer field. When only SubA_A_A_A is requested in
+    # test_get_state_tree, this sibling is not among the expected dict keys.
+    sub_a_a_a_c: int
+
+
+class SubA_A_B(SubA_A):
+    """SubA_A_B is a child of SubA_A."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_a_a_b: int
+
+
+class SubA_B(TreeA):
+    """SubA_B is a child of TreeA."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_a_b: int
+
+
+class TreeB(Root):
+    """TreeB is a child of Root."""
+
+    # Plain integer field; no class in the TreeB branch declares a computed var.
+    b: int
+
+
+class SubB_A(TreeB):
+    """SubB_A is a child of TreeB."""
+
+    # Plain integer field. When only SubB_B is requested in
+    # test_get_state_tree, this sibling is not among the expected dict keys.
+    sub_b_a: int
+
+
+class SubB_B(TreeB):
+    """SubB_B is a child of TreeB."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_b_b: int
+
+
+class SubB_C(TreeB):
+    """SubB_C is a child of TreeB."""
+
+    # Plain integer field; SubB_C is the parent of SubB_C_A.
+    sub_b_c: int
+
+
+class SubB_C_A(SubB_C):
+    """SubB_C_A is a child of SubB_C."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_b_c_a: int
+
+
+class TreeC(Root):
+    """TreeC is a child of Root."""
+
+    # Plain integer field; no class in the TreeC branch declares a computed var.
+    c: int
+
+
+class SubC_A(TreeC):
+    """SubC_A is a child of TreeC."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_c_a: int
+
+
+class TreeD(Root):
+    """TreeD is a child of Root.
+
+    It declares an `rx.var` computed var; every parametrized case in
+    test_get_state_tree expects "tree_d" among the root's substates.
+    """
+
+    # Input to the computed var below.
+    d: int
+
+    @rx.var
+    def d_var(self) -> int:
+        """A computed var.
+
+        Returns:
+            The value of d + 1
+        """
+        return self.d + 1
+
+
+class TreeE(Root):
+    """TreeE is a child of Root."""
+
+    # Plain integer field. The computed vars of this branch live four levels
+    # down, on SubE_A_A_A_A and SubE_A_A_A_D.
+    e: int
+
+
+class SubE_A(TreeE):
+    """SubE_A is a child of TreeE."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_e_a: int
+
+
+class SubE_A_A(SubE_A):
+    """SubE_A_A is a child of SubE_A."""
+
+    # Plain integer field; no computed vars are declared on this class.
+    sub_e_a_a: int
+
+
+class SubE_A_A_A(SubE_A_A):
+    """SubE_A_A_A is a child of SubE_A_A."""
+
+    # Read by `SubE_A_A_A_A.sub_e_a_a_a_a_var` and by the cached var
+    # `SubE_A_A_A_D.sub_e_a_a_a_d_var`.
+    sub_e_a_a_a: int
+
+
+class SubE_A_A_A_A(SubE_A_A_A):
+    """SubE_A_A_A_A is a child of SubE_A_A_A."""
+
+    # Plain integer field; note that the computed var below does not read it.
+    sub_e_a_a_a_a: int
+
+    @rx.var
+    def sub_e_a_a_a_a_var(self) -> int:
+        """A computed var.
+
+        Returns:
+            The value of sub_e_a_a_a + 1
+        """
+        # Reads the parent's `sub_e_a_a_a`, not this class's `sub_e_a_a_a_a`.
+        return self.sub_e_a_a_a + 1
+
+
+class SubE_A_A_A_B(SubE_A_A_A):
+    """SubE_A_A_A_B is a child of SubE_A_A_A."""
+
+    # Plain integer field. test_get_state_tree expects this class among the
+    # dict keys only for the Root and TreeE cases.
+    sub_e_a_a_a_b: int
+
+
+class SubE_A_A_A_C(SubE_A_A_A):
+    """SubE_A_A_A_C is a child of SubE_A_A_A."""
+
+    # Plain integer field. test_get_state_tree expects this class among the
+    # dict keys only for the Root and TreeE cases.
+    sub_e_a_a_a_c: int
+
+
+class SubE_A_A_A_D(SubE_A_A_A):
+    """SubE_A_A_A_D is a child of SubE_A_A_A."""
+
+    # Plain integer field; note that the cached var below does not read it.
+    sub_e_a_a_a_d: int
+
+    @rx.cached_var
+    def sub_e_a_a_a_d_var(self) -> int:
+        """A cached var.
+
+        Returns:
+            The value of sub_e_a_a_a + 1
+        """
+        # Reads the parent's `sub_e_a_a_a`, not this class's `sub_e_a_a_a_d`.
+        return self.sub_e_a_a_a + 1
+
+
+# Delta expected from every freshly fetched tree in test_get_state_tree: the
+# two `rx.var` (non-cached) computed vars. Both are defined as `<int field> + 1`,
+# so the expected value 1 assumes the int fields start at 0 -- TODO confirm the
+# default used by BaseState.
+ALWAYS_COMPUTED_VARS = {
+    TreeD.get_full_name(): {"d_var": 1},
+    SubE_A_A_A_A.get_full_name(): {"sub_e_a_a_a_a_var": 1},
+}
+
+# State names appended to the expected dict keys of every parametrized case:
+# Root, TreeD, and the TreeE chain down to SubE_A_A_A_A (the classes declaring
+# `rx.var` computed vars together with their ancestors), plus SubE_A_A_A_D,
+# which declares a cached var.
+ALWAYS_COMPUTED_DICT_KEYS = [
+    Root.get_full_name(),
+    TreeD.get_full_name(),
+    TreeE.get_full_name(),
+    SubE_A.get_full_name(),
+    SubE_A_A.get_full_name(),
+    SubE_A_A_A.get_full_name(),
+    SubE_A_A_A_A.get_full_name(),
+    SubE_A_A_A_D.get_full_name(),
+]
+
+
+@pytest.fixture(scope="function")
+def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]:
+    """Instance of state manager for redis only.
+
+    Builds an `rx.App` rooted at `Root`, and skips the requesting test when the
+    app's state manager is not a `StateManagerRedis`. The state manager is
+    closed after the test finishes.
+
+    Args:
+        app_module_mock: The app module mock fixture.
+
+    Yields:
+        A state manager instance
+    """
+    app_module_mock.app = rx.App(state=Root)
+    state_manager = app_module_mock.app.state_manager
+
+    # Partial-tree loading is only asserted for the redis-backed manager.
+    if not isinstance(state_manager, StateManagerRedis):
+        pytest.skip("Test requires redis")
+
+    yield state_manager
+
+    # Teardown: `close()` is awaited by driving it on the event loop, since
+    # this fixture itself is synchronous.
+    # NOTE(review): `asyncio.get_event_loop()` outside a running loop is
+    # deprecated in recent Python versions -- confirm this still works on the
+    # interpreter versions the test suite targets.
+    asyncio.get_event_loop().run_until_complete(state_manager.close())
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("substate_cls", "exp_root_substates", "exp_root_dict_keys"),
+    [
+        (
+            Root,
+            ["tree_a", "tree_b", "tree_c", "tree_d", "tree_e"],
+            [
+                TreeA.get_full_name(),
+                SubA_A.get_full_name(),
+                SubA_A_A.get_full_name(),
+                SubA_A_A_A.get_full_name(),
+                SubA_A_A_B.get_full_name(),
+                SubA_A_A_C.get_full_name(),
+                SubA_A_B.get_full_name(),
+                SubA_B.get_full_name(),
+                TreeB.get_full_name(),
+                SubB_A.get_full_name(),
+                SubB_B.get_full_name(),
+                SubB_C.get_full_name(),
+                SubB_C_A.get_full_name(),
+                TreeC.get_full_name(),
+                SubC_A.get_full_name(),
+                SubE_A_A_A_B.get_full_name(),
+                SubE_A_A_A_C.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeA,
+            ("tree_a", "tree_d", "tree_e"),
+            [
+                TreeA.get_full_name(),
+                SubA_A.get_full_name(),
+                SubA_A_A.get_full_name(),
+                SubA_A_A_A.get_full_name(),
+                SubA_A_A_B.get_full_name(),
+                SubA_A_A_C.get_full_name(),
+                SubA_A_B.get_full_name(),
+                SubA_B.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            SubA_A_A_A,
+            ["tree_a", "tree_d", "tree_e"],
+            [
+                TreeA.get_full_name(),
+                SubA_A.get_full_name(),
+                SubA_A_A.get_full_name(),
+                SubA_A_A_A.get_full_name(),
+                SubA_A_A_B.get_full_name(),  # Cached var dep
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeB,
+            ["tree_b", "tree_d", "tree_e"],
+            [
+                TreeB.get_full_name(),
+                SubB_A.get_full_name(),
+                SubB_B.get_full_name(),
+                SubB_C.get_full_name(),
+                SubB_C_A.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            SubB_B,
+            ["tree_b", "tree_d", "tree_e"],
+            [
+                TreeB.get_full_name(),
+                SubB_B.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            SubB_C_A,
+            ["tree_b", "tree_d", "tree_e"],
+            [
+                TreeB.get_full_name(),
+                SubB_C.get_full_name(),
+                SubB_C_A.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeC,
+            ["tree_c", "tree_d", "tree_e"],
+            [
+                TreeC.get_full_name(),
+                SubC_A.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeD,
+            ["tree_d", "tree_e"],
+            [
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+        (
+            TreeE,
+            ["tree_d", "tree_e"],
+            [
+                # Extra siblings of computed var included now.
+                SubE_A_A_A_B.get_full_name(),
+                SubE_A_A_A_C.get_full_name(),
+                *ALWAYS_COMPUTED_DICT_KEYS,
+            ],
+        ),
+    ],
+)
+async def test_get_state_tree(
+    state_manager_redis,
+    token,
+    substate_cls,
+    exp_root_substates,
+    exp_root_dict_keys,
+):
+    """Test getting state trees and assert on which branches are retrieved.
+
+    Each case requests one state class through the redis state manager and
+    checks which direct substates of the root were loaded, the initial delta,
+    and the set of state names present in the root's dict.
+
+    Args:
+        state_manager_redis: The state manager redis fixture.
+        token: The token fixture.
+        substate_cls: The substate class to retrieve.
+        exp_root_substates: The expected substates of the root state.
+        exp_root_dict_keys: The expected keys of the root state dict.
+    """
+    state = await state_manager_redis.get_state(_substate_key(token, substate_cls))
+    # The root of the tree is returned regardless of which class was requested.
+    assert isinstance(state, Root)
+    # Both sides are sorted, so the order within each case's expectation is irrelevant.
+    assert sorted(state.substates) == sorted(exp_root_substates)
+
+    # Only the always-computed (`rx.var`) vars should appear in the delta
+    assert state.get_delta() == ALWAYS_COMPUTED_VARS
+
+    # The dict must hold exactly the state names expected for this case: the
+    # requested branch plus the always-loaded computed var states
+    assert sorted(state.dict()) == sorted(exp_root_dict_keys)

+ 8 - 2
tests/utils/test_format.py

@@ -13,8 +13,11 @@ from reflex.vars import BaseVar, Var
 from tests.test_state import (
     ChildState,
     ChildState2,
+    ChildState3,
     DateTimeState,
     GrandchildState,
+    GrandchildState2,
+    GrandchildState3,
     TestState,
 )
 
@@ -649,7 +652,7 @@ formatted_router = {
     "input, output",
     [
         (
-            TestState().dict(),  # type: ignore
+            TestState(_reflex_internal_init=True).dict(),  # type: ignore
             {
                 TestState.get_full_name(): {
                     "array": [1, 2, 3.14],
@@ -674,11 +677,14 @@ formatted_router = {
                     "value": "",
                 },
                 ChildState2.get_full_name(): {"value": ""},
+                ChildState3.get_full_name(): {"value": ""},
                 GrandchildState.get_full_name(): {"value2": ""},
+                GrandchildState2.get_full_name(): {"cached": ""},
+                GrandchildState3.get_full_name(): {"computed": ""},
             },
         ),
         (
-            DateTimeState().dict(),
+            DateTimeState(_reflex_internal_init=True).dict(),  # type: ignore
             {
                 DateTimeState.get_full_name(): {
                     "d": "1989-11-09",