1
0
Эх сурвалжийг харах

[ENG-4326] Async ComputedVar (#4711)

* WiP

* Save the var from get_var_name

* flatten StateManagerRedis.get_state algorithm

simplify fetching of states and avoid repeatedly fetching the same state

* Get all the states in a single redis round-trip

* update docstrings in StateManagerRedis

* Move computed var dep tracking to separate module

* Fix pre-commit issues

* ComputedVar.add_dependency: explicitly dependency declaration

Allow var dependencies to be added at runtime, for example, when defining a
ComponentState that depends on vars that cannot be known statically.

Fix more pyright issues.

* Fix/ignore more pyright issues from recent merge

* handle cleaning out _potentially_dirty_states on reload

* ignore accessed attributes missing on state class

these might be added dynamically later in which case we recompute the
dependency tracking dicts... if not, they'll blow up anyway at runtime.

* fix playwright tests, which insist on running an asyncio loop

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Masen Furer 3 сар өмнө
parent
commit
a2243190ff

+ 11 - 5
reflex/app.py

@@ -908,11 +908,17 @@ class App(MiddlewareMixin, LifespanMixin):
             if not var._cache:
             if not var._cache:
                 continue
                 continue
             deps = var._deps(objclass=state)
             deps = var._deps(objclass=state)
-            for dep in deps:
-                if dep not in state.vars and dep not in state.backend_vars:
-                    raise exceptions.VarDependencyError(
-                        f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}"
-                    )
+            for state_name, dep_set in deps.items():
+                state_cls = (
+                    state.get_root_state().get_class_substate(state_name)
+                    if state_name != state.get_full_name()
+                    else state
+                )
+                for dep in dep_set:
+                    if dep not in state_cls.vars and dep not in state_cls.backend_vars:
+                        raise exceptions.VarDependencyError(
+                            f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}"
+                        )
 
 
         for substate in state.class_subclasses:
         for substate in state.class_subclasses:
             self._validate_var_dependencies(substate)
             self._validate_var_dependencies(substate)

+ 22 - 2
reflex/compiler/utils.py

@@ -2,12 +2,15 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+import asyncio
+import concurrent.futures
 import traceback
 import traceback
 from datetime import datetime
 from datetime import datetime
 from pathlib import Path
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Type, Union
 from typing import Any, Callable, Dict, Optional, Type, Union
 from urllib.parse import urlparse
 from urllib.parse import urlparse
 
 
+from reflex.utils.exec import is_in_app_harness
 from reflex.utils.prerequisites import get_web_dir
 from reflex.utils.prerequisites import get_web_dir
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
@@ -33,7 +36,7 @@ from reflex.components.base import (
 )
 )
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
 from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
-from reflex.state import BaseState
+from reflex.state import BaseState, _resolve_delta
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils.imports import ImportVar, ParsedImportDict
 from reflex.utils.imports import ImportVar, ParsedImportDict
@@ -177,7 +180,24 @@ def compile_state(state: Type[BaseState]) -> dict:
         initial_state = state(_reflex_internal_init=True).dict(
         initial_state = state(_reflex_internal_init=True).dict(
             initial=True, include_computed=False
             initial=True, include_computed=False
         )
         )
-    return initial_state
+    try:
+        _ = asyncio.get_running_loop()
+    except RuntimeError:
+        pass
+    else:
+        if is_in_app_harness():
+            # Playwright tests already have an event loop running, so we can't use asyncio.run.
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                resolved_initial_state = pool.submit(
+                    asyncio.run, _resolve_delta(initial_state)
+                ).result()
+                console.warn(
+                    f"Had to get initial state in a thread 🤮 {resolved_initial_state}",
+                )
+                return resolved_initial_state
+
+    # Normally the compile runs before any event loop starts, we asyncio.run is available for calling.
+    return asyncio.run(_resolve_delta(initial_state))
 
 
 
 
 def _compile_client_storage_field(
 def _compile_client_storage_field(

+ 2 - 2
reflex/middleware/hydrate_middleware.py

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
 from reflex import constants
 from reflex import constants
 from reflex.event import Event, get_hydrate_event
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware.middleware import Middleware
 from reflex.middleware.middleware import Middleware
-from reflex.state import BaseState, StateUpdate
+from reflex.state import BaseState, StateUpdate, _resolve_delta
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.app import App
     from reflex.app import App
@@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware):
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
 
 
         # Get the initial state.
         # Get the initial state.
-        delta = state.dict()
+        delta = await _resolve_delta(state.dict())
         # since a full dict was captured, clean any dirtiness
         # since a full dict was captured, clean any dirtiness
         state._clean()
         state._clean()
 
 

+ 253 - 323
reflex/state.py

@@ -15,7 +15,6 @@ import time
 import typing
 import typing
 import uuid
 import uuid
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
-from collections import defaultdict
 from hashlib import md5
 from hashlib import md5
 from pathlib import Path
 from pathlib import Path
 from types import FunctionType, MethodType
 from types import FunctionType, MethodType
@@ -329,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
     )
     )
 
 
 
 
+async def _resolve_delta(delta: Delta) -> Delta:
+    """Await all coroutines in the delta.
+
+    Args:
+        delta: The delta to process.
+
+    Returns:
+        The same delta dict with all coroutines resolved to their return value.
+    """
+    tasks = {}
+    for state_name, state_delta in delta.items():
+        for var_name, value in state_delta.items():
+            if asyncio.iscoroutine(value):
+                tasks[state_name, var_name] = asyncio.create_task(value)
+    for (state_name, var_name), task in tasks.items():
+        delta[state_name][var_name] = await task
+    return delta
+
+
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
     """The state of the app."""
 
 
@@ -356,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # A set of subclassses of this class.
     # A set of subclassses of this class.
     class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
     class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
 
 
-    # Mapping of var name to set of computed variables that depend on it
-    _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
-
-    # Mapping of var name to set of substates that depend on it
-    _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
+    # Mapping of var name to set of (state_full_name, var_name) that depend on it.
+    _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
 
 
     # Set of vars which always need to be recomputed
     # Set of vars which always need to be recomputed
     _always_dirty_computed_vars: ClassVar[Set[str]] = set()
     _always_dirty_computed_vars: ClassVar[Set[str]] = set()
@@ -368,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # Set of substates which always need to be recomputed
     # Set of substates which always need to be recomputed
     _always_dirty_substates: ClassVar[Set[str]] = set()
     _always_dirty_substates: ClassVar[Set[str]] = set()
 
 
+    # Set of states which might need to be recomputed if vars in this state change.
+    _potentially_dirty_states: ClassVar[Set[str]] = set()
+
     # The parent state.
     # The parent state.
     parent_state: Optional[BaseState] = None
     parent_state: Optional[BaseState] = None
 
 
@@ -519,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         # Reset dirty substate tracking for this class.
         # Reset dirty substate tracking for this class.
         cls._always_dirty_substates = set()
         cls._always_dirty_substates = set()
+        cls._potentially_dirty_states = set()
 
 
         # Get the parent vars.
         # Get the parent vars.
         parent_state = cls.get_parent_state()
         parent_state = cls.get_parent_state()
@@ -622,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             setattr(cls, name, handler)
             setattr(cls, name, handler)
 
 
         # Initialize per-class var dependency tracking.
         # Initialize per-class var dependency tracking.
-        cls._computed_var_dependencies = defaultdict(set)
-        cls._substate_var_dependencies = defaultdict(set)
+        cls._var_dependencies = {}
         cls._init_var_dependency_dicts()
         cls._init_var_dependency_dicts()
 
 
     @staticmethod
     @staticmethod
@@ -768,26 +786,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Additional updates tracking dicts for vars and substates that always
         Additional updates tracking dicts for vars and substates that always
         need to be recomputed.
         need to be recomputed.
         """
         """
-        inherited_vars = set(cls.inherited_vars).union(
-            set(cls.inherited_backend_vars),
-        )
         for cvar_name, cvar in cls.computed_vars.items():
         for cvar_name, cvar in cls.computed_vars.items():
-            # Add the dependencies.
-            for var in cvar._deps(objclass=cls):
-                cls._computed_var_dependencies[var].add(cvar_name)
-                if var in inherited_vars:
-                    # track that this substate depends on its parent for this var
-                    state_name = cls.get_name()
-                    parent_state = cls.get_parent_state()
-                    while parent_state is not None and var in {
-                        **parent_state.vars,
-                        **parent_state.backend_vars,
+            if not cvar._cache:
+                # Do not perform dep calculation when cache=False (these are always dirty).
+                continue
+            for state_name, dvar_set in cvar._deps(objclass=cls).items():
+                state_cls = cls.get_root_state().get_class_substate(state_name)
+                for dvar in dvar_set:
+                    defining_state_cls = state_cls
+                    while dvar in {
+                        *defining_state_cls.inherited_vars,
+                        *defining_state_cls.inherited_backend_vars,
                     }:
                     }:
-                        parent_state._substate_var_dependencies[var].add(state_name)
-                        state_name, parent_state = (
-                            parent_state.get_name(),
-                            parent_state.get_parent_state(),
-                        )
+                        parent_state = defining_state_cls.get_parent_state()
+                        if parent_state is not None:
+                            defining_state_cls = parent_state
+                    defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
+                        (cls.get_full_name(), cvar_name)
+                    )
+                    defining_state_cls._potentially_dirty_states.add(
+                        cls.get_full_name()
+                    )
 
 
         # ComputedVar with cache=False always need to be recomputed
         # ComputedVar with cache=False always need to be recomputed
         cls._always_dirty_computed_vars = {
         cls._always_dirty_computed_vars = {
@@ -902,6 +921,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             raise ValueError(f"Only one parent state is allowed {parent_states}.")
             raise ValueError(f"Only one parent state is allowed {parent_states}.")
         return parent_states[0] if len(parent_states) == 1 else None
         return parent_states[0] if len(parent_states) == 1 else None
 
 
+    @classmethod
+    @functools.lru_cache()
+    def get_root_state(cls) -> Type[BaseState]:
+        """Get the root state.
+
+        Returns:
+            The root state.
+        """
+        parent_state = cls.get_parent_state()
+        return cls if parent_state is None else parent_state.get_root_state()
+
     @classmethod
     @classmethod
     def get_substates(cls) -> set[Type[BaseState]]:
     def get_substates(cls) -> set[Type[BaseState]]:
         """Get the substates of the state.
         """Get the substates of the state.
@@ -1351,7 +1381,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         super().__setattr__(name, value)
         super().__setattr__(name, value)
 
 
         # Add the var to the dirty list.
         # Add the var to the dirty list.
-        if name in self.vars or name in self._computed_var_dependencies:
+        if name in self.base_vars:
             self.dirty_vars.add(name)
             self.dirty_vars.add(name)
             self._mark_dirty()
             self._mark_dirty()
 
 
@@ -1422,64 +1452,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return self.substates[path[0]].get_substate(path[1:])
         return self.substates[path[0]].get_substate(path[1:])
 
 
     @classmethod
     @classmethod
-    def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
-        """Find the name of the nearest common ancestor shared by this and the other state.
-
-        Args:
-            other: The other state.
-
-        Returns:
-            Full name of the nearest common ancestor.
-        """
-        common_ancestor_parts = []
-        for part1, part2 in zip(
-            cls.get_full_name().split("."),
-            other.get_full_name().split("."),
-            strict=True,
-        ):
-            if part1 != part2:
-                break
-            common_ancestor_parts.append(part1)
-        return ".".join(common_ancestor_parts)
-
-    @classmethod
-    def _determine_missing_parent_states(
-        cls, target_state_cls: Type[BaseState]
-    ) -> tuple[str, list[str]]:
-        """Determine the missing parent states between the target_state_cls and common ancestor of this state.
-
-        Args:
-            target_state_cls: The class of the state to find missing parent states for.
-
-        Returns:
-            The name of the common ancestor and the list of missing parent states.
-        """
-        common_ancestor_name = cls._get_common_ancestor(target_state_cls)
-        common_ancestor_parts = common_ancestor_name.split(".")
-        target_state_parts = tuple(target_state_cls.get_full_name().split("."))
-        relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
-
-        # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
-        fetch_parent_states = [common_ancestor_name]
-        for relative_parent_state_name in relative_target_state_parts:
-            fetch_parent_states.append(
-                ".".join((fetch_parent_states[-1], relative_parent_state_name))
-            )
-
-        return common_ancestor_name, fetch_parent_states[1:-1]
-
-    def _get_parent_states(self) -> list[tuple[str, BaseState]]:
-        """Get all parent state instances up to the root of the state tree.
+    def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
+        """Get substates which may have dirty vars due to dependencies.
 
 
         Returns:
         Returns:
-            A list of tuples containing the name and the instance of each parent state.
+            The set of potentially dirty substate classes.
         """
         """
-        parent_states_with_name = []
-        parent_state = self
-        while parent_state.parent_state is not None:
-            parent_state = parent_state.parent_state
-            parent_states_with_name.append((parent_state.get_full_name(), parent_state))
-        return parent_states_with_name
+        return {
+            cls.get_class_substate(substate_name)
+            for substate_name in cls._always_dirty_substates
+        }.union(
+            {
+                cls.get_root_state().get_class_substate(substate_name)
+                for substate_name in cls._potentially_dirty_states
+            }
+        )
 
 
     def _get_root_state(self) -> BaseState:
     def _get_root_state(self) -> BaseState:
         """Get the root state of the state tree.
         """Get the root state of the state tree.
@@ -1492,55 +1479,38 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             parent_state = parent_state.parent_state
             parent_state = parent_state.parent_state
         return parent_state
         return parent_state
 
 
-    async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
-        """Populate substates in the tree between the target_state_cls and common ancestor of this state.
+    async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
+        """Get a state instance from redis.
 
 
         Args:
         Args:
-            target_state_cls: The class of the state to populate parent states for.
+            state_cls: The class of the state.
 
 
         Returns:
         Returns:
-            The parent state instance of target_state_cls.
+            The instance of state_cls associated with this state's client_token.
 
 
         Raises:
         Raises:
             RuntimeError: If redis is not used in this backend process.
             RuntimeError: If redis is not used in this backend process.
+            StateMismatchError: If the state instance is not of the expected type.
         """
         """
+        # Then get the target state and all its substates.
         state_manager = get_state_manager()
         state_manager = get_state_manager()
         if not isinstance(state_manager, StateManagerRedis):
         if not isinstance(state_manager, StateManagerRedis):
             raise RuntimeError(
             raise RuntimeError(
-                f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
+                f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
                 "(All states should already be available -- this is likely a bug).",
                 "(All states should already be available -- this is likely a bug).",
             )
             )
+        state_in_redis = await state_manager.get_state(
+            token=_substate_key(self.router.session.client_token, state_cls),
+            top_level=False,
+            for_state_instance=self,
+        )
 
 
-        # Find the missing parent states up to the common ancestor.
-        (
-            common_ancestor_name,
-            missing_parent_states,
-        ) = self._determine_missing_parent_states(target_state_cls)
-
-        # Fetch all missing parent states and link them up to the common ancestor.
-        parent_states_tuple = self._get_parent_states()
-        root_state = parent_states_tuple[-1][1]
-        parent_states_by_name = dict(parent_states_tuple)
-        parent_state = parent_states_by_name[common_ancestor_name]
-        for parent_state_name in missing_parent_states:
-            try:
-                parent_state = root_state.get_substate(parent_state_name.split("."))
-                # The requested state is already cached, do NOT fetch it again.
-                continue
-            except ValueError:
-                # The requested state is missing, fetch from redis.
-                pass
-            parent_state = await state_manager.get_state(
-                token=_substate_key(
-                    self.router.session.client_token, parent_state_name
-                ),
-                top_level=False,
-                get_substates=False,
-                parent_state=parent_state,
+        if not isinstance(state_in_redis, state_cls):
+            raise StateMismatchError(
+                f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
             )
             )
 
 
-        # Return the direct parent of target_state_cls for subsequent linking.
-        return parent_state
+        return state_in_redis
 
 
     def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
     def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get a state instance from the cache.
         """Get a state instance from the cache.
@@ -1562,44 +1532,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             )
             )
         return substate
         return substate
 
 
-    async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
-        """Get a state instance from redis.
-
-        Args:
-            state_cls: The class of the state.
-
-        Returns:
-            The instance of state_cls associated with this state's client_token.
-
-        Raises:
-            RuntimeError: If redis is not used in this backend process.
-            StateMismatchError: If the state instance is not of the expected type.
-        """
-        # Fetch all missing parent states from redis.
-        parent_state_of_state_cls = await self._populate_parent_states(state_cls)
-
-        # Then get the target state and all its substates.
-        state_manager = get_state_manager()
-        if not isinstance(state_manager, StateManagerRedis):
-            raise RuntimeError(
-                f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
-                "(All states should already be available -- this is likely a bug).",
-            )
-
-        state_in_redis = await state_manager.get_state(
-            token=_substate_key(self.router.session.client_token, state_cls),
-            top_level=False,
-            get_substates=True,
-            parent_state=parent_state_of_state_cls,
-        )
-
-        if not isinstance(state_in_redis, state_cls):
-            raise StateMismatchError(
-                f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
-            )
-
-        return state_in_redis
-
     async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
     async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get an instance of the state associated with this token.
         """Get an instance of the state associated with this token.
 
 
@@ -1738,7 +1670,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
             f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
         )
         )
 
 
-    def _as_state_update(
+    async def _as_state_update(
         self,
         self,
         handler: EventHandler,
         handler: EventHandler,
         events: EventSpec | list[EventSpec] | None,
         events: EventSpec | list[EventSpec] | None,
@@ -1766,7 +1698,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         try:
         try:
             # Get the delta after processing the event.
             # Get the delta after processing the event.
-            delta = state.get_delta()
+            delta = await _resolve_delta(state.get_delta())
             state._clean()
             state._clean()
 
 
             return StateUpdate(
             return StateUpdate(
@@ -1866,24 +1798,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             # Handle async generators.
             # Handle async generators.
             if inspect.isasyncgen(events):
             if inspect.isasyncgen(events):
                 async for event in events:
                 async for event in events:
-                    yield state._as_state_update(handler, event, final=False)
-                yield state._as_state_update(handler, events=None, final=True)
+                    yield await state._as_state_update(handler, event, final=False)
+                yield await state._as_state_update(handler, events=None, final=True)
 
 
             # Handle regular generators.
             # Handle regular generators.
             elif inspect.isgenerator(events):
             elif inspect.isgenerator(events):
                 try:
                 try:
                     while True:
                     while True:
-                        yield state._as_state_update(handler, next(events), final=False)
+                        yield await state._as_state_update(
+                            handler, next(events), final=False
+                        )
                 except StopIteration as si:
                 except StopIteration as si:
                     # the "return" value of the generator is not available
                     # the "return" value of the generator is not available
                     # in the loop, we must catch StopIteration to access it
                     # in the loop, we must catch StopIteration to access it
                     if si.value is not None:
                     if si.value is not None:
-                        yield state._as_state_update(handler, si.value, final=False)
-                yield state._as_state_update(handler, events=None, final=True)
+                        yield await state._as_state_update(
+                            handler, si.value, final=False
+                        )
+                yield await state._as_state_update(handler, events=None, final=True)
 
 
             # Handle regular event chains.
             # Handle regular event chains.
             else:
             else:
-                yield state._as_state_update(handler, events, final=True)
+                yield await state._as_state_update(handler, events, final=True)
 
 
         # If an error occurs, throw a window alert.
         # If an error occurs, throw a window alert.
         except Exception as ex:
         except Exception as ex:
@@ -1893,7 +1829,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
                 prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
             )
             )
 
 
-            yield state._as_state_update(
+            yield await state._as_state_update(
                 handler,
                 handler,
                 event_specs,
                 event_specs,
                 final=True,
                 final=True,
@@ -1901,15 +1837,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
     def _mark_dirty_computed_vars(self) -> None:
     def _mark_dirty_computed_vars(self) -> None:
         """Mark ComputedVars that need to be recalculated based on dirty_vars."""
         """Mark ComputedVars that need to be recalculated based on dirty_vars."""
+        # Append expired computed vars to dirty_vars to trigger recalculation
+        self.dirty_vars.update(self._expired_computed_vars())
+        # Append always dirty computed vars to dirty_vars to trigger recalculation
+        self.dirty_vars.update(self._always_dirty_computed_vars)
+
         dirty_vars = self.dirty_vars
         dirty_vars = self.dirty_vars
         while dirty_vars:
         while dirty_vars:
             calc_vars, dirty_vars = dirty_vars, set()
             calc_vars, dirty_vars = dirty_vars, set()
-            for cvar in self._dirty_computed_vars(from_vars=calc_vars):
-                self.dirty_vars.add(cvar)
+            for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
+                if state_name == self.get_full_name():
+                    defining_state = self
+                else:
+                    defining_state = self._get_root_state().get_substate(
+                        tuple(state_name.split("."))
+                    )
+                defining_state.dirty_vars.add(cvar)
                 dirty_vars.add(cvar)
                 dirty_vars.add(cvar)
-                actual_var = self.computed_vars.get(cvar)
+                actual_var = defining_state.computed_vars.get(cvar)
                 if actual_var is not None:
                 if actual_var is not None:
-                    actual_var.mark_dirty(instance=self)
+                    actual_var.mark_dirty(instance=defining_state)
+                if defining_state is not self:
+                    defining_state._mark_dirty()
 
 
     def _expired_computed_vars(self) -> set[str]:
     def _expired_computed_vars(self) -> set[str]:
         """Determine ComputedVars that need to be recalculated based on the expiration time.
         """Determine ComputedVars that need to be recalculated based on the expiration time.
@@ -1925,7 +1874,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
     def _dirty_computed_vars(
     def _dirty_computed_vars(
         self, from_vars: set[str] | None = None, include_backend: bool = True
         self, from_vars: set[str] | None = None, include_backend: bool = True
-    ) -> set[str]:
+    ) -> set[tuple[str, str]]:
         """Determine ComputedVars that need to be recalculated based on the given vars.
         """Determine ComputedVars that need to be recalculated based on the given vars.
 
 
         Args:
         Args:
@@ -1936,33 +1885,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             Set of computed vars to include in the delta.
             Set of computed vars to include in the delta.
         """
         """
         return {
         return {
-            cvar
+            (state_name, cvar)
             for dirty_var in from_vars or self.dirty_vars
             for dirty_var in from_vars or self.dirty_vars
-            for cvar in self._computed_var_dependencies[dirty_var]
+            for state_name, cvar in self._var_dependencies.get(dirty_var, set())
             if include_backend or not self.computed_vars[cvar]._backend
             if include_backend or not self.computed_vars[cvar]._backend
         }
         }
 
 
-    @classmethod
-    def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
-        """Determine substates which could be affected by dirty vars in this state.
-
-        Returns:
-            Set of State classes that may need to be fetched to recalc computed vars.
-        """
-        # _always_dirty_substates need to be fetched to recalc computed vars.
-        fetch_substates = {
-            cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
-            for substate_name in cls._always_dirty_substates
-        }
-        for dependent_substates in cls._substate_var_dependencies.values():
-            fetch_substates.update(
-                {
-                    cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
-                    for substate_name in dependent_substates
-                }
-            )
-        return fetch_substates
-
     def get_delta(self) -> Delta:
     def get_delta(self) -> Delta:
         """Get the delta for the state.
         """Get the delta for the state.
 
 
@@ -1971,21 +1899,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         """
         delta = {}
         delta = {}
 
 
-        # Apply dirty variables down into substates
-        self.dirty_vars.update(self._always_dirty_computed_vars)
-        self._mark_dirty()
-
+        self._mark_dirty_computed_vars()
         frontend_computed_vars: set[str] = {
         frontend_computed_vars: set[str] = {
             name for name, cv in self.computed_vars.items() if not cv._backend
             name for name, cv in self.computed_vars.items() if not cv._backend
         }
         }
 
 
         # Return the dirty vars for this instance, any cached/dependent computed vars,
         # Return the dirty vars for this instance, any cached/dependent computed vars,
         # and always dirty computed vars (cache=False)
         # and always dirty computed vars (cache=False)
-        delta_vars = (
-            self.dirty_vars.intersection(self.base_vars)
-            .union(self.dirty_vars.intersection(frontend_computed_vars))
-            .union(self._dirty_computed_vars(include_backend=False))
-            .union(self._always_dirty_computed_vars)
+        delta_vars = self.dirty_vars.intersection(self.base_vars).union(
+            self.dirty_vars.intersection(frontend_computed_vars)
         )
         )
 
 
         subdelta: Dict[str, Any] = {
         subdelta: Dict[str, Any] = {
@@ -2015,23 +1937,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             self.parent_state.dirty_substates.add(self.get_name())
             self.parent_state.dirty_substates.add(self.get_name())
             self.parent_state._mark_dirty()
             self.parent_state._mark_dirty()
 
 
-        # Append expired computed vars to dirty_vars to trigger recalculation
-        self.dirty_vars.update(self._expired_computed_vars())
-
         # have to mark computed vars dirty to allow access to newly computed
         # have to mark computed vars dirty to allow access to newly computed
         # values within the same ComputedVar function
         # values within the same ComputedVar function
         self._mark_dirty_computed_vars()
         self._mark_dirty_computed_vars()
-        self._mark_dirty_substates()
-
-    def _mark_dirty_substates(self):
-        """Propagate dirty var / computed var status into substates."""
-        substates = self.substates
-        for var in self.dirty_vars:
-            for substate_name in self._substate_var_dependencies[var]:
-                self.dirty_substates.add(substate_name)
-                substate = substates[substate_name]
-                substate.dirty_vars.add(var)
-                substate._mark_dirty()
 
 
     def _update_was_touched(self):
     def _update_was_touched(self):
         """Update the _was_touched flag based on dirty_vars."""
         """Update the _was_touched flag based on dirty_vars."""
@@ -2103,11 +2011,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             The object as a dictionary.
             The object as a dictionary.
         """
         """
         if include_computed:
         if include_computed:
-            # Apply dirty variables down into substates to allow never-cached ComputedVar to
-            # trigger recalculation of dependent vars
-            self.dirty_vars.update(self._always_dirty_computed_vars)
-            self._mark_dirty()
-
+            self._mark_dirty_computed_vars()
         base_vars = {
         base_vars = {
             prop_name: self.get_value(prop_name) for prop_name in self.base_vars
             prop_name: self.get_value(prop_name) for prop_name in self.base_vars
         }
         }
@@ -2824,7 +2728,7 @@ class StateProxy(wrapt.ObjectProxy):
             await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
             await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
         )
         )
 
 
-    def _as_state_update(self, *args, **kwargs) -> StateUpdate:
+    async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
         """Temporarily allow mutability to access parent_state.
         """Temporarily allow mutability to access parent_state.
 
 
         Args:
         Args:
@@ -2837,7 +2741,7 @@ class StateProxy(wrapt.ObjectProxy):
         original_mutable = self._self_mutable
         original_mutable = self._self_mutable
         self._self_mutable = True
         self._self_mutable = True
         try:
         try:
-            return self.__wrapped__._as_state_update(*args, **kwargs)
+            return await self.__wrapped__._as_state_update(*args, **kwargs)
         finally:
         finally:
             self._self_mutable = original_mutable
             self._self_mutable = original_mutable
 
 
@@ -3313,103 +3217,106 @@ class StateManagerRedis(StateManager):
         b"evicted",
         b"evicted",
     }
     }
 
 
-    async def _get_parent_state(
-        self, token: str, state: BaseState | None = None
-    ) -> BaseState | None:
-        """Get the parent state for the state requested in the token.
+    def _get_required_state_classes(
+        self,
+        target_state_cls: Type[BaseState],
+        subclasses: bool = False,
+        required_state_classes: set[Type[BaseState]] | None = None,
+    ) -> set[Type[BaseState]]:
+        """Recursively determine which states are required to fetch the target state.
+
+        This will always include potentially dirty substates that depend on vars
+        in the target_state_cls.
 
 
         Args:
         Args:
-            token: The token to get the state for (_substate_key).
-            state: The state instance to get parent state for.
+            target_state_cls: The target state class being fetched.
+            subclasses: Whether to include subclasses of the target state.
+            required_state_classes: Recursive argument tracking state classes that have already been seen.
 
 
         Returns:
         Returns:
-            The parent state for the state requested by the token or None if there is no such parent.
-        """
-        parent_state = None
-        client_token, state_path = _split_substate_key(token)
-        parent_state_name = state_path.rpartition(".")[0]
-        if parent_state_name:
-            cached_substates = None
-            if state is not None:
-                cached_substates = [state]
-            # Retrieve the parent state to populate event handlers onto this substate.
-            parent_state = await self.get_state(
-                token=_substate_key(client_token, parent_state_name),
-                top_level=False,
-                get_substates=False,
-                cached_substates=cached_substates,
+            The set of state classes required to fetch the target state.
+        """
+        if required_state_classes is None:
+            required_state_classes = set()
+        # Get the substates if requested.
+        if subclasses:
+            for substate in target_state_cls.get_substates():
+                self._get_required_state_classes(
+                    substate,
+                    subclasses=True,
+                    required_state_classes=required_state_classes,
+                )
+        if target_state_cls in required_state_classes:
+            return required_state_classes
+        required_state_classes.add(target_state_cls)
+
+        # Get dependent substates.
+        for pd_substates in target_state_cls._get_potentially_dirty_states():
+            self._get_required_state_classes(
+                pd_substates,
+                subclasses=False,
+                required_state_classes=required_state_classes,
             )
             )
-        return parent_state
 
 
-    async def _populate_substates(
-        self,
-        token: str,
-        state: BaseState,
-        all_substates: bool = False,
-    ):
-        """Fetch and link substates for the given state instance.
+        # Get the parent state if it exists.
+        if parent_state := target_state_cls.get_parent_state():
+            self._get_required_state_classes(
+                parent_state,
+                subclasses=False,
+                required_state_classes=required_state_classes,
+            )
+        return required_state_classes
 
 
-        There is no return value; the side-effect is that `state` will have `substates` populated,
-        and each substate will have its `parent_state` set to `state`.
+    def _get_populated_states(
+        self,
+        target_state: BaseState,
+        populated_states: dict[str, BaseState] | None = None,
+    ) -> dict[str, BaseState]:
+        """Recursively determine which states from target_state are already fetched.
 
 
         Args:
         Args:
-            token: The token to get the state for.
-            state: The state instance to populate substates for.
-            all_substates: Whether to fetch all substates or just required substates.
-        """
-        client_token, _ = _split_substate_key(token)
+            target_state: The state to check for populated states.
+            populated_states: Recursive argument tracking states seen in previous calls.
 
 
-        if all_substates:
-            # All substates are requested.
-            fetch_substates = state.get_substates()
-        else:
-            # Only _potentially_dirty_substates need to be fetched to recalc computed vars.
-            fetch_substates = state._potentially_dirty_substates()
-
-        tasks = {}
-        # Retrieve the necessary substates from redis.
-        for substate_cls in fetch_substates:
-            if substate_cls.get_name() in state.substates:
-                continue
-            substate_name = substate_cls.get_name()
-            tasks[substate_name] = asyncio.create_task(
-                self.get_state(
-                    token=_substate_key(client_token, substate_cls),
-                    top_level=False,
-                    get_substates=all_substates,
-                    parent_state=state,
-                )
+        Returns:
+            A dictionary of state full name to state instance.
+        """
+        if populated_states is None:
+            populated_states = {}
+        if target_state.get_full_name() in populated_states:
+            return populated_states
+        populated_states[target_state.get_full_name()] = target_state
+        for substate in target_state.substates.values():
+            self._get_populated_states(substate, populated_states=populated_states)
+        if target_state.parent_state is not None:
+            self._get_populated_states(
+                target_state.parent_state, populated_states=populated_states
             )
             )
-
-        for substate_name, substate_task in tasks.items():
-            state.substates[substate_name] = await substate_task
+        return populated_states
 
 
     @override
     @override
     async def get_state(
     async def get_state(
         self,
         self,
         token: str,
         token: str,
         top_level: bool = True,
         top_level: bool = True,
-        get_substates: bool = True,
-        parent_state: BaseState | None = None,
-        cached_substates: list[BaseState] | None = None,
+        for_state_instance: BaseState | None = None,
     ) -> BaseState:
     ) -> BaseState:
         """Get the state for a token.
         """Get the state for a token.
 
 
         Args:
         Args:
             token: The token to get the state for.
             token: The token to get the state for.
             top_level: If true, return an instance of the top-level state (self.state).
             top_level: If true, return an instance of the top-level state (self.state).
-            get_substates: If true, also retrieve substates.
-            parent_state: If provided, use this parent_state instead of getting it from redis.
-            cached_substates: If provided, attach these substates to the state.
+            for_state_instance: If provided, attach the requested states to this existing state tree.
 
 
         Returns:
         Returns:
             The state for the token.
             The state for the token.
 
 
         Raises:
         Raises:
-            RuntimeError: when the state_cls is not specified in the token
+            RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
+                requested state was not fetched.
         """
         """
         # Split the actual token from the fully qualified substate name.
         # Split the actual token from the fully qualified substate name.
-        _, state_path = _split_substate_key(token)
+        token, state_path = _split_substate_key(token)
         if state_path:
         if state_path:
             # Get the State class associated with the given path.
             # Get the State class associated with the given path.
             state_cls = self.state.get_class_substate(state_path)
             state_cls = self.state.get_class_substate(state_path)
@@ -3418,43 +3325,59 @@ class StateManagerRedis(StateManager):
                 f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
                 f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
             )
             )
 
 
-        # The deserialized or newly created (sub)state instance.
-        state = None
-
-        # Fetch the serialized substate from redis.
-        redis_state = await self.redis.get(token)
-
-        if redis_state is not None:
-            # Deserialize the substate.
-            with contextlib.suppress(StateSchemaMismatchError):
-                state = BaseState._deserialize(data=redis_state)
-        if state is None:
-            # Key didn't exist or schema mismatch so create a new instance for this token.
-            state = state_cls(
-                init_substates=False,
-                _reflex_internal_init=True,
-            )
-        # Populate parent state if missing and requested.
-        if parent_state is None:
-            parent_state = await self._get_parent_state(token, state)
-        # Set up Bidirectional linkage between this state and its parent.
-        if parent_state is not None:
-            parent_state.substates[state.get_name()] = state
-            state.parent_state = parent_state
-        # Avoid fetching substates multiple times.
-        if cached_substates:
-            for substate in cached_substates:
-                state.substates[substate.get_name()] = substate
-                if substate.parent_state is None:
-                    substate.parent_state = state
-        # Populate substates if requested.
-        await self._populate_substates(token, state, all_substates=get_substates)
+        # Determine which states we already have.
+        flat_state_tree: dict[str, BaseState] = (
+            self._get_populated_states(for_state_instance) if for_state_instance else {}
+        )
+
+        # Determine which states from the tree need to be fetched.
+        required_state_classes = sorted(
+            self._get_required_state_classes(state_cls, subclasses=True)
+            - {type(s) for s in flat_state_tree.values()},
+            key=lambda x: x.get_full_name(),
+        )
+
+        redis_pipeline = self.redis.pipeline()
+        for state_cls in required_state_classes:
+            redis_pipeline.get(_substate_key(token, state_cls))
+
+        for state_cls, redis_state in zip(
+            required_state_classes,
+            await redis_pipeline.execute(),
+            strict=False,
+        ):
+            state = None
+
+            if redis_state is not None:
+                # Deserialize the substate.
+                with contextlib.suppress(StateSchemaMismatchError):
+                    state = BaseState._deserialize(data=redis_state)
+            if state is None:
+                # Key didn't exist or schema mismatch so create a new instance for this token.
+                state = state_cls(
+                    init_substates=False,
+                    _reflex_internal_init=True,
+                )
+            flat_state_tree[state.get_full_name()] = state
+            if state.get_parent_state() is not None:
+                parent_state_name, _dot, state_name = state.get_full_name().rpartition(
+                    "."
+                )
+                parent_state = flat_state_tree.get(parent_state_name)
+                if parent_state is None:
+                    raise RuntimeError(
+                        f"Parent state for {state.get_full_name()} was not found "
+                        "in the state tree, but should have already been fetched. "
+                        "This is a bug",
+                    )
+                parent_state.substates[state_name] = state
+                state.parent_state = parent_state
 
 
         # To retain compatibility with previous implementation, by default, we return
         # To retain compatibility with previous implementation, by default, we return
-        # the top-level state by chasing `parent_state` pointers up the tree.
+        # the top-level state which should always be fetched or already cached.
         if top_level:
         if top_level:
-            return state._get_root_state()
-        return state
+            return flat_state_tree[self.state.get_full_name()]
+        return flat_state_tree[state_cls.get_full_name()]
 
 
     @override
     @override
     async def set_state(
     async def set_state(
@@ -4154,12 +4077,19 @@ def reload_state_module(
         state: Recursive argument for the state class to reload.
         state: Recursive argument for the state class to reload.
 
 
     """
     """
+    # Clean out all potentially dirty states of reloaded modules.
+    for pd_state in tuple(state._potentially_dirty_states):
+        with contextlib.suppress(ValueError):
+            if (
+                state.get_root_state().get_class_substate(pd_state).__module__ == module
+                and module is not None
+            ):
+                state._potentially_dirty_states.remove(pd_state)
     for subclass in tuple(state.class_subclasses):
     for subclass in tuple(state.class_subclasses):
         reload_state_module(module=module, state=subclass)
         reload_state_module(module=module, state=subclass)
         if subclass.__module__ == module and module is not None:
         if subclass.__module__ == module and module is not None:
             state.class_subclasses.remove(subclass)
             state.class_subclasses.remove(subclass)
             state._always_dirty_substates.discard(subclass.get_name())
             state._always_dirty_substates.discard(subclass.get_name())
-            state._computed_var_dependencies = defaultdict(set)
-            state._substate_var_dependencies = defaultdict(set)
+            state._var_dependencies = {}
             state._init_var_dependency_dicts()
             state._init_var_dependency_dicts()
     state.get_class_substate.cache_clear()
     state.get_class_substate.cache_clear()

+ 1 - 1
reflex/utils/exec.py

@@ -488,7 +488,7 @@ def output_system_info():
         dependencies.append(fnm_info)
         dependencies.append(fnm_info)
 
 
     if system == "Linux":
     if system == "Linux":
-        import distro
+        import distro  # pyright: ignore[reportMissingImports]
 
 
         os_version = distro.name(pretty=True)
         os_version = distro.name(pretty=True)
     else:
     else:

+ 238 - 109
reflex/vars/base.py

@@ -5,7 +5,6 @@ from __future__ import annotations
 import contextlib
 import contextlib
 import dataclasses
 import dataclasses
 import datetime
 import datetime
-import dis
 import functools
 import functools
 import inspect
 import inspect
 import json
 import json
@@ -20,6 +19,7 @@ from typing import (
     Any,
     Any,
     Callable,
     Callable,
     ClassVar,
     ClassVar,
+    Coroutine,
     Dict,
     Dict,
     FrozenSet,
     FrozenSet,
     Generic,
     Generic,
@@ -51,7 +51,6 @@ from reflex.utils.exceptions import (
     VarAttributeError,
     VarAttributeError,
     VarDependencyError,
     VarDependencyError,
     VarTypeError,
     VarTypeError,
-    VarValueError,
 )
 )
 from reflex.utils.format import format_state_name
 from reflex.utils.format import format_state_name
 from reflex.utils.imports import (
 from reflex.utils.imports import (
@@ -1983,7 +1982,7 @@ class ComputedVar(Var[RETURN_TYPE]):
     _initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
     _initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
 
 
     # Explicit var dependencies to track
     # Explicit var dependencies to track
-    _static_deps: set[str] = dataclasses.field(default_factory=set)
+    _static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict)
 
 
     # Whether var dependencies should be auto-determined
     # Whether var dependencies should be auto-determined
     _auto_deps: bool = dataclasses.field(default=True)
     _auto_deps: bool = dataclasses.field(default=True)
@@ -2053,21 +2052,34 @@ class ComputedVar(Var[RETURN_TYPE]):
 
 
         object.__setattr__(self, "_update_interval", interval)
         object.__setattr__(self, "_update_interval", interval)
 
 
-        if deps is None:
-            deps = []
-        else:
+        _static_deps = {}
+        if isinstance(deps, dict):
+            # Assume a dict is coming from _replace, so no special processing.
+            _static_deps = deps
+        elif deps is not None:
             for dep in deps:
             for dep in deps:
                 if isinstance(dep, Var):
                 if isinstance(dep, Var):
-                    continue
-                if isinstance(dep, str) and dep != "":
-                    continue
-                raise TypeError(
-                    "ComputedVar dependencies must be Var instances or var names (non-empty strings)."
-                )
+                    state_name = (
+                        all_var_data.state
+                        if (all_var_data := dep._get_all_var_data())
+                        and all_var_data.state
+                        else None
+                    )
+                    if all_var_data is not None:
+                        var_name = all_var_data.field_name
+                    else:
+                        var_name = dep._js_expr
+                    _static_deps.setdefault(state_name, set()).add(var_name)
+                elif isinstance(dep, str) and dep != "":
+                    _static_deps.setdefault(None, set()).add(dep)
+                else:
+                    raise TypeError(
+                        "ComputedVar dependencies must be Var instances or var names (non-empty strings)."
+                    )
         object.__setattr__(
         object.__setattr__(
             self,
             self,
             "_static_deps",
             "_static_deps",
-            {dep._js_expr if isinstance(dep, Var) else dep for dep in deps},
+            _static_deps,
         )
         )
         object.__setattr__(self, "_auto_deps", auto_deps)
         object.__setattr__(self, "_auto_deps", auto_deps)
 
 
@@ -2149,6 +2161,13 @@ class ComputedVar(Var[RETURN_TYPE]):
             return True
             return True
         return datetime.datetime.now() - last_updated > self._update_interval
         return datetime.datetime.now() - last_updated > self._update_interval
 
 
+    @overload
+    def __get__(
+        self: ComputedVar[bool],
+        instance: None,
+        owner: Type,
+    ) -> BooleanVar: ...
+
     @overload
     @overload
     def __get__(
     def __get__(
         self: ComputedVar[int] | ComputedVar[float],
         self: ComputedVar[int] | ComputedVar[float],
@@ -2233,125 +2252,67 @@ class ComputedVar(Var[RETURN_TYPE]):
                 setattr(instance, self._last_updated_attr, datetime.datetime.now())
                 setattr(instance, self._last_updated_attr, datetime.datetime.now())
             value = getattr(instance, self._cache_attr)
             value = getattr(instance, self._cache_attr)
 
 
+        self._check_deprecated_return_type(instance, value)
+
+        return value
+
+    def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None:
         if not _isinstance(value, self._var_type):
         if not _isinstance(value, self._var_type):
             console.error(
             console.error(
                 f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
                 f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
                 f" type '{self._var_type}', got '{type(value)}'."
                 f" type '{self._var_type}', got '{type(value)}'."
             )
             )
 
 
-        return value
-
     def _deps(
     def _deps(
         self,
         self,
-        objclass: Type,
+        objclass: Type[BaseState],
         obj: FunctionType | CodeType | None = None,
         obj: FunctionType | CodeType | None = None,
-        self_name: Optional[str] = None,
-    ) -> set[str]:
+    ) -> dict[str, set[str]]:
         """Determine var dependencies of this ComputedVar.
         """Determine var dependencies of this ComputedVar.
 
 
-        Save references to attributes accessed on "self".  Recursively called
-        when the function makes a method call on "self" or define comprehensions
-        or nested functions that may reference "self".
+        Save references to attributes accessed on "self" or other fetched states.
+
+        Recursively called when the function makes a method call on "self" or
+        define comprehensions or nested functions that may reference "self".
 
 
         Args:
         Args:
             objclass: the class obj this ComputedVar is attached to.
             objclass: the class obj this ComputedVar is attached to.
             obj: the object to disassemble (defaults to the fget function).
             obj: the object to disassemble (defaults to the fget function).
-            self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
 
 
         Returns:
         Returns:
-            A set of variable names accessed by the given obj.
-
-        Raises:
-            VarValueError: if the function references the get_state, parent_state, or substates attributes
-                (cannot track deps in a related state, only implicitly via parent state).
+            A dictionary mapping state names to the set of variable names
+            accessed by the given obj.
         """
         """
+        from .dep_tracking import DependencyTracker
+
+        d = {}
+        if self._static_deps:
+            d.update(self._static_deps)
+            # None is a placeholder for the current state class.
+            if None in d:
+                d[objclass.get_full_name()] = d.pop(None)
+
         if not self._auto_deps:
         if not self._auto_deps:
-            return self._static_deps
-        d = self._static_deps.copy()
+            return d
+
         if obj is None:
         if obj is None:
             fget = self._fget
             fget = self._fget
             if fget is not None:
             if fget is not None:
                 obj = cast(FunctionType, fget)
                 obj = cast(FunctionType, fget)
             else:
             else:
-                return set()
-        with contextlib.suppress(AttributeError):
-            # unbox functools.partial
-            obj = cast(FunctionType, obj.func)  # pyright: ignore [reportAttributeAccessIssue]
-        with contextlib.suppress(AttributeError):
-            # unbox EventHandler
-            obj = cast(FunctionType, obj.fn)  # pyright: ignore [reportAttributeAccessIssue]
+                return d
 
 
-        if self_name is None and isinstance(obj, FunctionType):
-            try:
-                # the first argument to the function is the name of "self" arg
-                self_name = obj.__code__.co_varnames[0]
-            except (AttributeError, IndexError):
-                self_name = None
-        if self_name is None:
-            # cannot reference attributes on self if method takes no args
-            return set()
-
-        invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
-        self_is_top_of_stack = False
-        for instruction in dis.get_instructions(obj):
-            if (
-                instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
-                and instruction.argval == self_name
-            ):
-                # bytecode loaded the class instance to the top of stack, next load instruction
-                # is referencing an attribute on self
-                self_is_top_of_stack = True
-                continue
-            if self_is_top_of_stack and instruction.opname in (
-                "LOAD_ATTR",
-                "LOAD_METHOD",
-            ):
-                try:
-                    ref_obj = getattr(objclass, instruction.argval)
-                except Exception:
-                    ref_obj = None
-                if instruction.argval in invalid_names:
-                    raise VarValueError(
-                        f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
-                    )
-                if callable(ref_obj):
-                    # recurse into callable attributes
-                    d.update(
-                        self._deps(
-                            objclass=objclass,
-                            obj=ref_obj,  # pyright: ignore [reportArgumentType]
-                        )
-                    )
-                # recurse into property fget functions
-                elif isinstance(ref_obj, property) and not isinstance(
-                    ref_obj, ComputedVar
-                ):
-                    d.update(
-                        self._deps(
-                            objclass=objclass,
-                            obj=ref_obj.fget,  # pyright: ignore [reportArgumentType]
-                        )
-                    )
-                elif (
-                    instruction.argval in objclass.backend_vars
-                    or instruction.argval in objclass.vars
-                ):
-                    # var access
-                    d.add(instruction.argval)
-            elif instruction.opname == "LOAD_CONST" and isinstance(
-                instruction.argval, CodeType
-            ):
-                # recurse into nested functions / comprehensions, which can reference
-                # instance attributes from the outer scope
-                d.update(
-                    self._deps(
-                        objclass=objclass,
-                        obj=instruction.argval,
-                        self_name=self_name,
-                    )
-                )
-            self_is_top_of_stack = False
-        return d
+        try:
+            return DependencyTracker(
+                func=obj, state_cls=objclass, dependencies=d
+            ).dependencies
+        except Exception as e:
+            console.warn(
+                "Failed to automatically determine dependencies for computed var "
+                f"{objclass.__name__}.{self._js_expr}: {e}. "
+                "Provide static_deps and set auto_deps=False to suppress this warning."
+            )
+            return d
 
 
     def mark_dirty(self, instance: BaseState) -> None:
     def mark_dirty(self, instance: BaseState) -> None:
         """Mark this ComputedVar as dirty.
         """Mark this ComputedVar as dirty.
@@ -2362,6 +2323,37 @@ class ComputedVar(Var[RETURN_TYPE]):
         with contextlib.suppress(AttributeError):
         with contextlib.suppress(AttributeError):
             delattr(instance, self._cache_attr)
             delattr(instance, self._cache_attr)
 
 
+    def add_dependency(self, objclass: Type[BaseState], dep: Var):
+        """Explicitly add a dependency to the ComputedVar.
+
+        After adding the dependency, when the `dep` changes, this computed var
+        will be marked dirty.
+
+        Args:
+            objclass: The class obj this ComputedVar is attached to.
+            dep: The dependency to add.
+
+        Raises:
+            VarDependencyError: If the dependency is not a Var instance with a
+                state and field name
+        """
+        if all_var_data := dep._get_all_var_data():
+            state_name = all_var_data.state
+            if state_name:
+                var_name = all_var_data.field_name
+                if var_name:
+                    self._static_deps.setdefault(state_name, set()).add(var_name)
+                    objclass.get_root_state().get_class_substate(
+                        state_name
+                    )._var_dependencies.setdefault(var_name, set()).add(
+                        (objclass.get_full_name(), self._js_expr)
+                    )
+                    return
+        raise VarDependencyError(
+            "ComputedVar dependencies must be Var instances with a state and "
+            f"field name, got {dep!r}."
+        )
+
     def _determine_var_type(self) -> Type:
     def _determine_var_type(self) -> Type:
         """Get the type of the var.
         """Get the type of the var.
 
 
@@ -2398,6 +2390,126 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
     pass
     pass
 
 
 
 
+async def _default_async_computed_var(_self: BaseState) -> Any:
+    return None
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    init=False,
+    slots=True,
+)
+class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
+    """A computed var that wraps a coroutinefunction."""
+
+    _fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = (
+        dataclasses.field(default=_default_async_computed_var)
+    )
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[bool],
+        instance: None,
+        owner: Type,
+    ) -> BooleanVar: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[int] | ComputedVar[float],
+        instance: None,
+        owner: Type,
+    ) -> NumberVar: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[str],
+        instance: None,
+        owner: Type,
+    ) -> StringVar: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]],
+        instance: None,
+        owner: Type,
+    ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[list[LIST_INSIDE]],
+        instance: None,
+        owner: Type,
+    ) -> ArrayVar[list[LIST_INSIDE]]: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[tuple[LIST_INSIDE, ...]],
+        instance: None,
+        owner: Type,
+    ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
+
+    @overload
+    def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ...
+
+    @overload
+    def __get__(
+        self, instance: BaseState, owner: Type
+    ) -> Coroutine[None, None, RETURN_TYPE]: ...
+
+    def __get__(
+        self, instance: BaseState | None, owner
+    ) -> Var | Coroutine[None, None, RETURN_TYPE]:
+        """Get the ComputedVar value.
+
+        If the value is already cached on the instance, return the cached value.
+
+        Args:
+            instance: the instance of the class accessing this computed var.
+            owner: the class that this descriptor is attached to.
+
+        Returns:
+            The value of the var for the given instance.
+        """
+        if instance is None:
+            return super(AsyncComputedVar, self).__get__(instance, owner)
+
+        if not self._cache:
+
+            async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
+                value = await self.fget(instance)
+                self._check_deprecated_return_type(instance, value)
+                return value
+
+            return _awaitable_result()
+        else:
+            # handle caching
+            async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
+                if not hasattr(instance, self._cache_attr) or self.needs_update(
+                    instance
+                ):
+                    # Set cache attr on state instance.
+                    setattr(instance, self._cache_attr, await self.fget(instance))
+                    # Ensure the computed var gets serialized to redis.
+                    instance._was_touched = True
+                    # Set the last updated timestamp on the state instance.
+                    setattr(instance, self._last_updated_attr, datetime.datetime.now())
+                value = getattr(instance, self._cache_attr)
+                self._check_deprecated_return_type(instance, value)
+                return value
+
+            return _awaitable_result()
+
+    @property
+    def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]:
+        """Get the getter function.
+
+        Returns:
+            The getter function.
+        """
+        return self._fget
+
+
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
     BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
 
 
@@ -2464,10 +2576,27 @@ def computed_var(
         raise VarDependencyError("Cannot track dependencies without caching.")
         raise VarDependencyError("Cannot track dependencies without caching.")
 
 
     if fget is not None:
     if fget is not None:
-        return ComputedVar(fget, cache=cache)
+        if inspect.iscoroutinefunction(fget):
+            computed_var_cls = AsyncComputedVar
+        else:
+            computed_var_cls = ComputedVar
+        return computed_var_cls(
+            fget,
+            initial_value=initial_value,
+            cache=cache,
+            deps=deps,
+            auto_deps=auto_deps,
+            interval=interval,
+            backend=backend,
+            **kwargs,
+        )
 
 
     def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
     def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
-        return ComputedVar(
+        if inspect.iscoroutinefunction(fget):
+            computed_var_cls = AsyncComputedVar
+        else:
+            computed_var_cls = ComputedVar
+        return computed_var_cls(
             fget,
             fget,
             initial_value=initial_value,
             initial_value=initial_value,
             cache=cache,
             cache=cache,

+ 344 - 0
reflex/vars/dep_tracking.py

@@ -0,0 +1,344 @@
+"""Collection of base classes."""
+
+from __future__ import annotations
+
+import contextlib
+import dataclasses
+import dis
+import enum
+import inspect
+from types import CellType, CodeType, FunctionType
+from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
+
+from reflex.utils.exceptions import VarValueError
+
+if TYPE_CHECKING:
+    from reflex.state import BaseState
+
+    from .base import Var
+
+
+CellEmpty = object()
+
+
+def get_cell_value(cell: CellType) -> Any:
+    """Get the value of a cell object.
+
+    Args:
+        cell: The cell object to get the value from. (func.__closure__ objects)
+
+    Returns:
+        The value from the cell or CellEmpty if a ValueError is raised.
+    """
+    try:
+        return cell.cell_contents
+    except ValueError:
+        return CellEmpty
+
+
+class ScanStatus(enum.Enum):
+    """State of the dis instruction scanning loop."""
+
+    SCANNING = enum.auto()
+    GETTING_ATTR = enum.auto()
+    GETTING_STATE = enum.auto()
+    GETTING_VAR = enum.auto()
+
+
+@dataclasses.dataclass
+class DependencyTracker:
+    """State machine for identifying state attributes that are accessed by a function."""
+
+    func: FunctionType | CodeType = dataclasses.field()
+    state_cls: Type[BaseState] = dataclasses.field()
+
+    dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
+
+    scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
+    top_of_stack: str | None = dataclasses.field(default=None)
+
+    tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)
+
+    _getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)
+    _getting_var_instructions: list[dis.Instruction] = dataclasses.field(
+        default_factory=list
+    )
+
+    INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
+
+    def __post_init__(self):
+        """After initializing, populate the dependencies dict."""
+        with contextlib.suppress(AttributeError):
+            # unbox functools.partial
+            self.func = cast(FunctionType, self.func.func)  # pyright: ignore[reportAttributeAccessIssue]
+        with contextlib.suppress(AttributeError):
+            # unbox EventHandler
+            self.func = cast(FunctionType, self.func.fn)  # pyright: ignore[reportAttributeAccessIssue]
+
+        if isinstance(self.func, FunctionType):
+            with contextlib.suppress(AttributeError, IndexError):
+                # the first argument to the function is the name of "self" arg
+                self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
+
+        self._populate_dependencies()
+
+    def _merge_deps(self, tracker: DependencyTracker) -> None:
+        """Merge dependencies from another DependencyTracker.
+
+        Args:
+            tracker: The DependencyTracker to merge dependencies from.
+        """
+        for state_name, dep_name in tracker.dependencies.items():
+            self.dependencies.setdefault(state_name, set()).update(dep_name)
+
+    def load_attr_or_method(self, instruction: dis.Instruction) -> None:
+        """Handle loading an attribute or method from the object on top of the stack.
+
+        This method directly tracks attributes and recursively merges
+        dependencies from analyzing the dependencies of any methods called.
+
+        Args:
+            instruction: The dis instruction to process.
+
+        Raises:
+            VarValueError: if the attribute is an disallowed name.
+        """
+        from .base import ComputedVar
+
+        if instruction.argval in self.INVALID_NAMES:
+            raise VarValueError(
+                f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
+            )
+        if instruction.argval == "get_state":
+            # Special case: arbitrary state access requested.
+            self.scan_status = ScanStatus.GETTING_STATE
+            return
+        if instruction.argval == "get_var_value":
+            # Special case: arbitrary var access requested.
+            self.scan_status = ScanStatus.GETTING_VAR
+            return
+
+        # Reset status back to SCANNING after attribute is accessed.
+        self.scan_status = ScanStatus.SCANNING
+        if not self.top_of_stack:
+            return
+        target_state = self.tracked_locals[self.top_of_stack]
+        try:
+            ref_obj = getattr(target_state, instruction.argval)
+        except AttributeError:
+            # Not found on this state class, maybe it is a dynamic attribute that will be picked up later.
+            ref_obj = None
+
+        if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
+            # recurse into property fget functions
+            ref_obj = ref_obj.fget
+        if callable(ref_obj):
+            # recurse into callable attributes
+            self._merge_deps(
+                type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
+            )
+        elif (
+            instruction.argval in target_state.backend_vars
+            or instruction.argval in target_state.vars
+        ):
+            # var access
+            self.dependencies.setdefault(target_state.get_full_name(), set()).add(
+                instruction.argval
+            )
+
+    def _get_globals(self) -> dict[str, Any]:
+        """Get the globals of the function.
+
+        Returns:
+            The var names and values in the globals of the function.
+        """
+        if isinstance(self.func, CodeType):
+            return {}
+        return self.func.__globals__  # pyright: ignore[reportAttributeAccessIssue]
+
+    def _get_closure(self) -> dict[str, Any]:
+        """Get the closure of the function, with unbound values omitted.
+
+        Returns:
+            The var names and values in the closure of the function.
+        """
+        if isinstance(self.func, CodeType):
+            return {}
+        return {
+            var_name: get_cell_value(cell)
+            for var_name, cell in zip(
+                self.func.__code__.co_freevars,  # pyright: ignore[reportAttributeAccessIssue]
+                self.func.__closure__ or (),
+                strict=False,
+            )
+            if get_cell_value(cell) is not CellEmpty
+        }
+
+    def handle_getting_state(self, instruction: dis.Instruction) -> None:
+        """Handle bytecode analysis when `get_state` was called in the function.
+
+        If the wrapped function is getting an arbitrary state and saving it to a
+        local variable, this method associates the local variable name with the
+        state class in self.tracked_locals.
+
+        When an attribute/method is accessed on a tracked local, it will be
+        associated with this state.
+
+        Args:
+            instruction: The dis instruction to process.
+
+        Raises:
+            VarValueError: if the state class cannot be determined from the instruction.
+        """
+        from reflex.state import BaseState
+
+        if instruction.opname == "LOAD_FAST":
+            raise VarValueError(
+                f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
+            )
+        if isinstance(self.func, CodeType):
+            raise VarValueError(
+                "Dependency detection cannot identify get_state class from a code object."
+            )
+        if instruction.opname == "LOAD_GLOBAL":
+            # Special case: referencing state class from global scope.
+            try:
+                self._getting_state_class = self._get_globals()[instruction.argval]
+            except (ValueError, KeyError) as ve:
+                raise VarValueError(
+                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
+                ) from ve
+        elif instruction.opname == "LOAD_DEREF":
+            # Special case: referencing state class from closure.
+            try:
+                self._getting_state_class = self._get_closure()[instruction.argval]
+            except (ValueError, KeyError) as ve:
+                raise VarValueError(
+                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
+                ) from ve
+        elif instruction.opname == "STORE_FAST":
+            # Storing the result of get_state in a local variable.
+            if not isinstance(self._getting_state_class, type) or not issubclass(
+                self._getting_state_class, BaseState
+            ):
+                raise VarValueError(
+                    f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
+                )
+            self.tracked_locals[instruction.argval] = self._getting_state_class
+            self.scan_status = ScanStatus.SCANNING
+            self._getting_state_class = None
+
+    def _eval_var(self) -> Var:
+        """Evaluate instructions from the wrapped function to get the Var object.
+
+        Returns:
+            The Var object.
+
+        Raises:
+            VarValueError: if the source code for the var cannot be determined.
+        """
+        # Get the original source code and eval it to get the Var.
+        module = inspect.getmodule(self.func)
+        positions0 = self._getting_var_instructions[0].positions
+        positions1 = self._getting_var_instructions[-1].positions
+        if module is None or positions0 is None or positions1 is None:
+            raise VarValueError(
+                f"Cannot determine the source code for the var in {self.func!r}."
+            )
+        start_line = positions0.lineno
+        start_column = positions0.col_offset
+        end_line = positions1.end_lineno
+        end_column = positions1.end_col_offset
+        if (
+            start_line is None
+            or start_column is None
+            or end_line is None
+            or end_column is None
+        ):
+            raise VarValueError(
+                f"Cannot determine the source code for the var in {self.func!r}."
+            )
+        source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line]
+        # Create a python source string snippet.
+        if len(source) > 1:
+            snipped_source = "".join(
+                [
+                    *source[0][start_column:],
+                    *(source[1:-2] if len(source) > 2 else []),
+                    *source[-1][: end_column - 1],
+                ]
+            )
+        else:
+            snipped_source = source[0][start_column : end_column - 1]
+        # Evaluate the string in the context of the function's globals and closure.
+        return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
+
+    def handle_getting_var(self, instruction: dis.Instruction) -> None:
+        """Handle bytecode analysis when `get_var_value` was called in the function.
+
+        This only really works if the expression passed to `get_var_value` is
+        evaluable in the function's global scope or closure, so getting the var
+        value from a var saved in a local variable or in the class instance is
+        not possible.
+
+        Args:
+            instruction: The dis instruction to process.
+
+        Raises:
+            VarValueError: if the source code for the var cannot be determined.
+        """
+        if instruction.opname == "CALL" and self._getting_var_instructions:
+            if self._getting_var_instructions:
+                the_var = self._eval_var()
+                the_var_data = the_var._get_all_var_data()
+                if the_var_data is None:
+                    raise VarValueError(
+                        f"Cannot determine the source code for the var in {self.func!r}."
+                    )
+                self.dependencies.setdefault(the_var_data.state, set()).add(
+                    the_var_data.field_name
+                )
+            self._getting_var_instructions.clear()
+            self.scan_status = ScanStatus.SCANNING
+        else:
+            self._getting_var_instructions.append(instruction)
+
+    def _populate_dependencies(self) -> None:
+        """Update self.dependencies based on the disassembly of self.func.
+
+        Save references to attributes accessed on "self" or other fetched states.
+
+        Recursively called when the function makes a method call on "self" or
+        define comprehensions or nested functions that may reference "self".
+        """
+        for instruction in dis.get_instructions(self.func):
+            if self.scan_status == ScanStatus.GETTING_STATE:
+                self.handle_getting_state(instruction)
+            elif self.scan_status == ScanStatus.GETTING_VAR:
+                self.handle_getting_var(instruction)
+            elif (
+                instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
+                and instruction.argval in self.tracked_locals
+            ):
+                # bytecode loaded the class instance to the top of stack, next load instruction
+                # is referencing an attribute on self
+                self.top_of_stack = instruction.argval
+                self.scan_status = ScanStatus.GETTING_ATTR
+            elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
+                "LOAD_ATTR",
+                "LOAD_METHOD",
+            ):
+                self.load_attr_or_method(instruction)
+                self.top_of_stack = None
+            elif instruction.opname == "LOAD_CONST" and isinstance(
+                instruction.argval, CodeType
+            ):
+                # recurse into nested functions / comprehensions, which can reference
+                # instance attributes from the outer scope
+                self._merge_deps(
+                    type(self)(
+                        func=instruction.argval,
+                        state_cls=self.state_cls,
+                        tracked_locals=self.tracked_locals,
+                    )
+                )

+ 7 - 5
tests/integration/tests_playwright/test_table.py

@@ -3,7 +3,7 @@
 from typing import Generator
 from typing import Generator
 
 
 import pytest
 import pytest
-from playwright.sync_api import Page
+from playwright.sync_api import Page, expect
 
 
 from reflex.testing import AppHarness
 from reflex.testing import AppHarness
 
 
@@ -87,12 +87,14 @@ def test_table(page: Page, table_app: AppHarness):
     table = page.get_by_role("table")
     table = page.get_by_role("table")
 
 
     # Check column headers
     # Check column headers
-    headers = table.get_by_role("columnheader").all_inner_texts()
-    assert headers == expected_col_headers
+    headers = table.get_by_role("columnheader")
+    for header, exp_value in zip(headers.all(), expected_col_headers, strict=True):
+        expect(header).to_have_text(exp_value)
 
 
     # Check rows headers
     # Check rows headers
-    rows = table.get_by_role("rowheader").all_inner_texts()
-    assert rows == expected_row_headers
+    rows = table.get_by_role("rowheader")
+    for row, expected_row in zip(rows.all(), expected_row_headers, strict=True):
+        expect(row).to_have_text(expected_row)
 
 
     # Check cells
     # Check cells
     rows = table.get_by_role("cell").all_inner_texts()
     rows = table.get_by_role("cell").all_inner_texts()

+ 14 - 4
tests/units/test_app.py

@@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
     assert app._pages.keys() == {"test/[dynamic]"}
     assert app._pages.keys() == {"test/[dynamic]"}
     assert "dynamic" in app._state.computed_vars
     assert "dynamic" in app._state.computed_vars
     assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
     assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
-        constants.ROUTER
+        EmptyState.get_full_name(): {constants.ROUTER},
     }
     }
-    assert constants.ROUTER in app._state()._computed_var_dependencies
+    assert constants.ROUTER in app._state()._var_dependencies
 
 
 
 
 def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
 def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@@ -995,9 +995,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     assert arg_name in app._state.vars
     assert arg_name in app._state.vars
     assert arg_name in app._state.computed_vars
     assert arg_name in app._state.computed_vars
     assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
     assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
-        constants.ROUTER
+        DynamicState.get_full_name(): {constants.ROUTER},
     }
     }
-    assert constants.ROUTER in app._state()._computed_var_dependencies
+    assert constants.ROUTER in app._state()._var_dependencies
 
 
     substate_token = _substate_key(token, DynamicState)
     substate_token = _substate_key(token, DynamicState)
     sid = "mock_sid"
     sid = "mock_sid"
@@ -1555,6 +1555,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
         def bar(self) -> str:
         def bar(self) -> str:
             return "bar"
             return "bar"
 
 
+    class Child1(ValidDepState):
+        @computed_var(deps=["base", ValidDepState.bar])
+        def other(self) -> str:
+            return "other"
+
+    class Child2(ValidDepState):
+        @computed_var(deps=["base", Child1.other])
+        def other(self) -> str:
+            return "other"
+
     app._state = ValidDepState
     app._state = ValidDepState
     app._compile()
     app._compile()
 
 

+ 174 - 38
tests/units/test_state.py

@@ -14,6 +14,7 @@ from typing import (
     Any,
     Any,
     AsyncGenerator,
     AsyncGenerator,
     Callable,
     Callable,
+    ClassVar,
     Dict,
     Dict,
     List,
     List,
     Optional,
     Optional,
@@ -1169,13 +1170,17 @@ def test_conditional_computed_vars():
 
 
     ms = MainState()
     ms = MainState()
     # Initially there are no dirty computed vars.
     # Initially there are no dirty computed vars.
-    assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
-    assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
-    assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
+    assert ms._dirty_computed_vars(from_vars={"flag"}) == {
+        (MainState.get_full_name(), "rendered_var")
+    }
+    assert ms._dirty_computed_vars(from_vars={"t2"}) == {
+        (MainState.get_full_name(), "rendered_var")
+    }
+    assert ms._dirty_computed_vars(from_vars={"t1"}) == {
+        (MainState.get_full_name(), "rendered_var")
+    }
     assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
     assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
-        "flag",
-        "t1",
-        "t2",
+        MainState.get_full_name(): {"flag", "t1", "t2"}
     }
     }
 
 
 
 
@@ -1370,7 +1375,10 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
         assert isinstance(HandlerState.handler, EventHandler)
         assert isinstance(HandlerState.handler, EventHandler)
 
 
     s = HandlerState()
     s = HandlerState()
-    assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
+    assert (
+        HandlerState.get_full_name(),
+        "cached_x_side_effect",
+    ) in s._var_dependencies["x"]
     assert s.cached_x_side_effect == 1
     assert s.cached_x_side_effect == 1
     assert s.x == 43
     assert s.x == 43
     s.handler()
     s.handler()
@@ -1460,15 +1468,15 @@ def test_computed_var_dependencies():
             return [z in self._z for z in range(5)]
             return [z in self._z for z in range(5)]
 
 
     cs = ComputedState()
     cs = ComputedState()
-    assert cs._computed_var_dependencies["v"] == {
-        "comp_v",
-        "comp_v_backend",
-        "comp_v_via_property",
+    assert cs._var_dependencies["v"] == {
+        (ComputedState.get_full_name(), "comp_v"),
+        (ComputedState.get_full_name(), "comp_v_backend"),
+        (ComputedState.get_full_name(), "comp_v_via_property"),
     }
     }
-    assert cs._computed_var_dependencies["w"] == {"comp_w"}
-    assert cs._computed_var_dependencies["x"] == {"comp_x"}
-    assert cs._computed_var_dependencies["y"] == {"comp_y"}
-    assert cs._computed_var_dependencies["_z"] == {"comp_z"}
+    assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")}
+    assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")}
+    assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")}
+    assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")}
 
 
 
 
 def test_backend_method():
 def test_backend_method():
@@ -3180,7 +3188,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
 RxState = State
 RxState = State
 
 
 
 
-def test_potentially_dirty_substates():
+def test_potentially_dirty_states():
     """Test that potentially_dirty_substates returns the correct substates.
     """Test that potentially_dirty_substates returns the correct substates.
 
 
     Even if the name "State" is shadowed, it should still work correctly.
     Even if the name "State" is shadowed, it should still work correctly.
@@ -3196,13 +3204,19 @@ def test_potentially_dirty_substates():
         def bar(self) -> str:
         def bar(self) -> str:
             return ""
             return ""
 
 
-    assert RxState._potentially_dirty_substates() == set()
-    assert State._potentially_dirty_substates() == set()
-    assert C1._potentially_dirty_substates() == set()
+    assert RxState._get_potentially_dirty_states() == set()
+    assert State._get_potentially_dirty_states() == set()
+    assert C1._get_potentially_dirty_states() == set()
+
 
 
+@pytest.mark.asyncio
+async def test_router_var_dep(state_manager: StateManager, token: str) -> None:
+    """Test that router var dependencies are correctly tracked.
 
 
-def test_router_var_dep() -> None:
-    """Test that router var dependencies are correctly tracked."""
+    Args:
+        state_manager: A state manager.
+        token: A token.
+    """
 
 
     class RouterVarParentState(State):
     class RouterVarParentState(State):
         """A parent state for testing router var dependency."""
         """A parent state for testing router var dependency."""
@@ -3219,30 +3233,27 @@ def test_router_var_dep() -> None:
     foo = RouterVarDepState.computed_vars["foo"]
     foo = RouterVarDepState.computed_vars["foo"]
     State._init_var_dependency_dicts()
     State._init_var_dependency_dicts()
 
 
-    assert foo._deps(objclass=RouterVarDepState) == {"router"}
-    assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
-    assert RouterVarParentState._substate_var_dependencies == {
-        "router": {RouterVarDepState.get_name()}
-    }
-    assert RouterVarDepState._computed_var_dependencies == {
-        "router": {"foo"},
+    assert foo._deps(objclass=RouterVarDepState) == {
+        RouterVarDepState.get_full_name(): {"router"}
     }
     }
+    assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[
+        "router"
+    ]
 
 
-    rx_state = State()
-    parent_state = RouterVarParentState()
-    state = RouterVarDepState()
-
-    # link states
-    rx_state.substates = {RouterVarParentState.get_name(): parent_state}
-    parent_state.parent_state = rx_state
-    state.parent_state = parent_state
-    parent_state.substates = {RouterVarDepState.get_name(): state}
+    # Get state from state manager.
+    state_manager.state = State
+    rx_state = await state_manager.get_state(_substate_key(token, State))
+    assert RouterVarParentState.get_name() in rx_state.substates
+    parent_state = rx_state.substates[RouterVarParentState.get_name()]
+    assert RouterVarDepState.get_name() in parent_state.substates
+    state = parent_state.substates[RouterVarDepState.get_name()]
 
 
     assert state.dirty_vars == set()
     assert state.dirty_vars == set()
 
 
     # Reassign router var
     # Reassign router var
     state.router = state.router
     state.router = state.router
-    assert state.dirty_vars == {"foo", "router"}
+    assert rx_state.dirty_vars == {"router"}
+    assert state.dirty_vars == {"foo"}
     assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
     assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
 
 
 
 
@@ -3801,3 +3812,128 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str):
     # Generic Var with no state
     # Generic Var with no state
     with pytest.raises(UnretrievableVarValueError):
     with pytest.raises(UnretrievableVarValueError):
         await state.get_var_value(rx.Var("undefined"))
         await state.get_var_value(rx.Var("undefined"))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
+    """A test where an async computed var depends on a var in another state.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+
+    class Parent(BaseState):
+        """A root state like rx.State."""
+
+        parent_var: int = 0
+
+    class Child2(Parent):
+        """An unconnected child state."""
+
+        pass
+
+    class Child3(Parent):
+        """A child state with a computed var causing it to be pre-fetched.
+
+        If child3_var gets set to a value, and `get_state` erroneously
+        re-fetches it from redis, the value will be lost.
+        """
+
+        child3_var: int = 0
+
+        @rx.var(cache=True)
+        def v(self) -> int:
+            return self.child3_var
+
+    class Child(Parent):
+        """A state simulating UpdateVarsInternalState."""
+
+        @rx.var(cache=True)
+        async def v(self) -> int:
+            p = await self.get_state(Parent)
+            child3 = await self.get_state(Child3)
+            return child3.child3_var + p.parent_var
+
+    mock_app.state_manager.state = mock_app._state = Parent
+
+    # Get the top level state via unconnected sibling.
+    root = await mock_app.state_manager.get_state(_substate_key(token, Child))
+    # Set value in parent_var to assert it does not get refetched later.
+    root.parent_var = 1
+
+    if isinstance(mock_app.state_manager, StateManagerRedis):
+        # When redis is used, only states with uncached computed vars are pre-fetched.
+        assert Child2.get_name() not in root.substates
+        assert Child3.get_name() not in root.substates
+
+    # Get the unconnected sibling state, which will be used to `get_state` other instances.
+    child = root.get_substate(Child.get_full_name().split("."))
+
+    # Get an uncached child state.
+    child2 = await child.get_state(Child2)
+    assert child2.parent_var == 1
+
+    # Set value on already-cached Child3 state (prefetched because it has a Computed Var).
+    child3 = await child.get_state(Child3)
+    child3.child3_var = 1
+
+    assert await child.v == 2
+    assert await child.v == 2
+    root.parent_var = 2
+    assert await child.v == 3
+
+
+class Table(rx.ComponentState):
+    """A table state."""
+
+    data: ClassVar[Var]
+
+    @rx.var(cache=True, auto_deps=False)
+    async def rows(self) -> List[Dict[str, Any]]:
+        """Computed var over the given rows.
+
+        Returns:
+            The data rows.
+        """
+        return await self.get_var_value(self.data)
+
+    @classmethod
+    def get_component(cls, data: Var) -> rx.Component:
+        """Get the component for the table.
+
+        Args:
+            data: The data var.
+
+        Returns:
+            The component.
+        """
+        cls.data = data
+        cls.computed_vars["rows"].add_dependency(cls, data)
+        return rx.foreach(data, lambda d: rx.text(d.to_string()))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
+    """A test where an async computed var depends on a var in another state.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+
+    class OtherState(rx.State):
+        """A state with a var."""
+
+        data: List[Dict[str, Any]] = [{"foo": "bar"}]
+
+    mock_app.state_manager.state = mock_app._state = rx.State
+    comp = Table.create(data=OtherState.data)
+    state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
+    other_state = await state.get_state(OtherState)
+    assert comp.State is not None
+    comp_state = await state.get_state(comp.State)
+    assert comp_state.dirty_vars == set()
+
+    other_state.data.append({"foo": "baz"})
+    assert "rows" in comp_state.dirty_vars

+ 25 - 3
tests/units/test_var.py

@@ -1807,9 +1807,9 @@ def cv_fget(state: BaseState) -> int:
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "deps,expected",
     "deps,expected",
     [
     [
-        (["a"], {"a"}),
-        (["b"], {"b"}),
-        ([ComputedVar(fget=cv_fget)], {"cv_fget"}),
+        (["a"], {None: {"a"}}),
+        (["b"], {None: {"b"}}),
+        ([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}),
     ],
     ],
 )
 )
 def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
 def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
@@ -1857,6 +1857,28 @@ def test_to_string_operation():
     assert single_var._var_type == Email
     assert single_var._var_type == Email
 
 
 
 
+@pytest.mark.asyncio
+async def test_async_computed_var():
+    side_effect_counter = 0
+
+    class AsyncComputedVarState(BaseState):
+        v: int = 1
+
+        @computed_var(cache=True)
+        async def async_computed_var(self) -> int:
+            nonlocal side_effect_counter
+            side_effect_counter += 1
+            return self.v + 1
+
+    my_state = AsyncComputedVarState()
+    assert await my_state.async_computed_var == 2
+    assert await my_state.async_computed_var == 2
+    my_state.v = 2
+    assert await my_state.async_computed_var == 3
+    assert await my_state.async_computed_var == 3
+    assert side_effect_counter == 2
+
+
 def test_var_data_hooks():
 def test_var_data_hooks():
     var_data_str = VarData(hooks="what")
     var_data_str = VarData(hooks="what")
     var_data_list = VarData(hooks=["what"])
     var_data_list = VarData(hooks=["what"])