Преглед на файлове

[ENG-4326] Async ComputedVar (#4711)

* WiP

* Save the var from get_var_name

* flatten StateManagerRedis.get_state algorithm

simplify fetching of states and avoid repeatedly fetching the same state

* Get all the states in a single redis round-trip

* update docstrings in StateManagerRedis

* Move computed var dep tracking to separate module

* Fix pre-commit issues

* ComputedVar.add_dependency: explicitly dependency declaration

Allow var dependencies to be added at runtime, for example, when defining a
ComponentState that depends on vars that cannot be known statically.

Fix more pyright issues.

* Fix/ignore more pyright issues from recent merge

* handle cleaning out _potentially_dirty_states on reload

* ignore accessed attributes missing on state class

these might be added dynamically later in which case we recompute the
dependency tracking dicts... if not, they'll blow up anyway at runtime.

* fix playwright tests, which insist on running an asyncio loop

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Masen Furer преди 3 месеца
родител
ревизия
a2243190ff

+ 11 - 5
reflex/app.py

@@ -908,11 +908,17 @@ class App(MiddlewareMixin, LifespanMixin):
             if not var._cache:
                 continue
             deps = var._deps(objclass=state)
-            for dep in deps:
-                if dep not in state.vars and dep not in state.backend_vars:
-                    raise exceptions.VarDependencyError(
-                        f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}"
-                    )
+            for state_name, dep_set in deps.items():
+                state_cls = (
+                    state.get_root_state().get_class_substate(state_name)
+                    if state_name != state.get_full_name()
+                    else state
+                )
+                for dep in dep_set:
+                    if dep not in state_cls.vars and dep not in state_cls.backend_vars:
+                        raise exceptions.VarDependencyError(
+                            f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}"
+                        )
 
         for substate in state.class_subclasses:
             self._validate_var_dependencies(substate)

+ 22 - 2
reflex/compiler/utils.py

@@ -2,12 +2,15 @@
 
 from __future__ import annotations
 
+import asyncio
+import concurrent.futures
 import traceback
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Type, Union
 from urllib.parse import urlparse
 
+from reflex.utils.exec import is_in_app_harness
 from reflex.utils.prerequisites import get_web_dir
 from reflex.vars.base import Var
 
@@ -33,7 +36,7 @@ from reflex.components.base import (
 )
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
-from reflex.state import BaseState
+from reflex.state import BaseState, _resolve_delta
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils.imports import ImportVar, ParsedImportDict
@@ -177,7 +180,24 @@ def compile_state(state: Type[BaseState]) -> dict:
         initial_state = state(_reflex_internal_init=True).dict(
             initial=True, include_computed=False
         )
-    return initial_state
+    try:
+        _ = asyncio.get_running_loop()
+    except RuntimeError:
+        pass
+    else:
+        if is_in_app_harness():
+            # Playwright tests already have an event loop running, so we can't use asyncio.run.
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                resolved_initial_state = pool.submit(
+                    asyncio.run, _resolve_delta(initial_state)
+                ).result()
+                console.warn(
+                    f"Had to get initial state in a thread 🤮 {resolved_initial_state}",
+                )
+                return resolved_initial_state
+
+    # Normally the compile runs before any event loop starts, we asyncio.run is available for calling.
+    return asyncio.run(_resolve_delta(initial_state))
 
 
 def _compile_client_storage_field(

+ 2 - 2
reflex/middleware/hydrate_middleware.py

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
 from reflex import constants
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware.middleware import Middleware
-from reflex.state import BaseState, StateUpdate
+from reflex.state import BaseState, StateUpdate, _resolve_delta
 
 if TYPE_CHECKING:
     from reflex.app import App
@@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware):
         setattr(state, constants.CompileVars.IS_HYDRATED, False)
 
         # Get the initial state.
-        delta = state.dict()
+        delta = await _resolve_delta(state.dict())
         # since a full dict was captured, clean any dirtiness
         state._clean()
 

+ 253 - 323
reflex/state.py

@@ -15,7 +15,6 @@ import time
 import typing
 import uuid
 from abc import ABC, abstractmethod
-from collections import defaultdict
 from hashlib import md5
 from pathlib import Path
 from types import FunctionType, MethodType
@@ -329,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
     )
 
 
+async def _resolve_delta(delta: Delta) -> Delta:
+    """Await all coroutines in the delta.
+
+    Args:
+        delta: The delta to process.
+
+    Returns:
+        The same delta dict with all coroutines resolved to their return value.
+    """
+    tasks = {}
+    for state_name, state_delta in delta.items():
+        for var_name, value in state_delta.items():
+            if asyncio.iscoroutine(value):
+                tasks[state_name, var_name] = asyncio.create_task(value)
+    for (state_name, var_name), task in tasks.items():
+        delta[state_name][var_name] = await task
+    return delta
+
+
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
 
@@ -356,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # A set of subclassses of this class.
     class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
 
-    # Mapping of var name to set of computed variables that depend on it
-    _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
-
-    # Mapping of var name to set of substates that depend on it
-    _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
+    # Mapping of var name to set of (state_full_name, var_name) that depend on it.
+    _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
 
     # Set of vars which always need to be recomputed
     _always_dirty_computed_vars: ClassVar[Set[str]] = set()
@@ -368,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # Set of substates which always need to be recomputed
     _always_dirty_substates: ClassVar[Set[str]] = set()
 
+    # Set of states which might need to be recomputed if vars in this state change.
+    _potentially_dirty_states: ClassVar[Set[str]] = set()
+
     # The parent state.
     parent_state: Optional[BaseState] = None
 
@@ -519,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         # Reset dirty substate tracking for this class.
         cls._always_dirty_substates = set()
+        cls._potentially_dirty_states = set()
 
         # Get the parent vars.
         parent_state = cls.get_parent_state()
@@ -622,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             setattr(cls, name, handler)
 
         # Initialize per-class var dependency tracking.
-        cls._computed_var_dependencies = defaultdict(set)
-        cls._substate_var_dependencies = defaultdict(set)
+        cls._var_dependencies = {}
         cls._init_var_dependency_dicts()
 
     @staticmethod
@@ -768,26 +786,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Additional updates tracking dicts for vars and substates that always
         need to be recomputed.
         """
-        inherited_vars = set(cls.inherited_vars).union(
-            set(cls.inherited_backend_vars),
-        )
         for cvar_name, cvar in cls.computed_vars.items():
-            # Add the dependencies.
-            for var in cvar._deps(objclass=cls):
-                cls._computed_var_dependencies[var].add(cvar_name)
-                if var in inherited_vars:
-                    # track that this substate depends on its parent for this var
-                    state_name = cls.get_name()
-                    parent_state = cls.get_parent_state()
-                    while parent_state is not None and var in {
-                        **parent_state.vars,
-                        **parent_state.backend_vars,
+            if not cvar._cache:
+                # Do not perform dep calculation when cache=False (these are always dirty).
+                continue
+            for state_name, dvar_set in cvar._deps(objclass=cls).items():
+                state_cls = cls.get_root_state().get_class_substate(state_name)
+                for dvar in dvar_set:
+                    defining_state_cls = state_cls
+                    while dvar in {
+                        *defining_state_cls.inherited_vars,
+                        *defining_state_cls.inherited_backend_vars,
                     }:
-                        parent_state._substate_var_dependencies[var].add(state_name)
-                        state_name, parent_state = (
-                            parent_state.get_name(),
-                            parent_state.get_parent_state(),
-                        )
+                        parent_state = defining_state_cls.get_parent_state()
+                        if parent_state is not None:
+                            defining_state_cls = parent_state
+                    defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
+                        (cls.get_full_name(), cvar_name)
+                    )
+                    defining_state_cls._potentially_dirty_states.add(
+                        cls.get_full_name()
+                    )
 
         # ComputedVar with cache=False always need to be recomputed
         cls._always_dirty_computed_vars = {
@@ -902,6 +921,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             raise ValueError(f"Only one parent state is allowed {parent_states}.")
         return parent_states[0] if len(parent_states) == 1 else None
 
+    @classmethod
+    @functools.lru_cache()
+    def get_root_state(cls) -> Type[BaseState]:
+        """Get the root state.
+
+        Returns:
+            The root state.
+        """
+        parent_state = cls.get_parent_state()
+        return cls if parent_state is None else parent_state.get_root_state()
+
     @classmethod
     def get_substates(cls) -> set[Type[BaseState]]:
         """Get the substates of the state.
@@ -1351,7 +1381,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         super().__setattr__(name, value)
 
         # Add the var to the dirty list.
-        if name in self.vars or name in self._computed_var_dependencies:
+        if name in self.base_vars:
             self.dirty_vars.add(name)
             self._mark_dirty()
 
@@ -1422,64 +1452,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return self.substates[path[0]].get_substate(path[1:])
 
     @classmethod
-    def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
-        """Find the name of the nearest common ancestor shared by this and the other state.
-
-        Args:
-            other: The other state.
-
-        Returns:
-            Full name of the nearest common ancestor.
-        """
-        common_ancestor_parts = []
-        for part1, part2 in zip(
-            cls.get_full_name().split("."),
-            other.get_full_name().split("."),
-            strict=True,
-        ):
-            if part1 != part2:
-                break
-            common_ancestor_parts.append(part1)
-        return ".".join(common_ancestor_parts)
-
-    @classmethod
-    def _determine_missing_parent_states(
-        cls, target_state_cls: Type[BaseState]
-    ) -> tuple[str, list[str]]:
-        """Determine the missing parent states between the target_state_cls and common ancestor of this state.
-
-        Args:
-            target_state_cls: The class of the state to find missing parent states for.
-
-        Returns:
-            The name of the common ancestor and the list of missing parent states.
-        """
-        common_ancestor_name = cls._get_common_ancestor(target_state_cls)
-        common_ancestor_parts = common_ancestor_name.split(".")
-        target_state_parts = tuple(target_state_cls.get_full_name().split("."))
-        relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
-
-        # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
-        fetch_parent_states = [common_ancestor_name]
-        for relative_parent_state_name in relative_target_state_parts:
-            fetch_parent_states.append(
-                ".".join((fetch_parent_states[-1], relative_parent_state_name))
-            )
-
-        return common_ancestor_name, fetch_parent_states[1:-1]
-
-    def _get_parent_states(self) -> list[tuple[str, BaseState]]:
-        """Get all parent state instances up to the root of the state tree.
+    def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
+        """Get substates which may have dirty vars due to dependencies.
 
         Returns:
-            A list of tuples containing the name and the instance of each parent state.
+            The set of potentially dirty substate classes.
         """
-        parent_states_with_name = []
-        parent_state = self
-        while parent_state.parent_state is not None:
-            parent_state = parent_state.parent_state
-            parent_states_with_name.append((parent_state.get_full_name(), parent_state))
-        return parent_states_with_name
+        return {
+            cls.get_class_substate(substate_name)
+            for substate_name in cls._always_dirty_substates
+        }.union(
+            {
+                cls.get_root_state().get_class_substate(substate_name)
+                for substate_name in cls._potentially_dirty_states
+            }
+        )
 
     def _get_root_state(self) -> BaseState:
         """Get the root state of the state tree.
@@ -1492,55 +1479,38 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             parent_state = parent_state.parent_state
         return parent_state
 
-    async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
-        """Populate substates in the tree between the target_state_cls and common ancestor of this state.
+    async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
+        """Get a state instance from redis.
 
         Args:
-            target_state_cls: The class of the state to populate parent states for.
+            state_cls: The class of the state.
 
         Returns:
-            The parent state instance of target_state_cls.
+            The instance of state_cls associated with this state's client_token.
 
         Raises:
             RuntimeError: If redis is not used in this backend process.
+            StateMismatchError: If the state instance is not of the expected type.
         """
+        # Then get the target state and all its substates.
         state_manager = get_state_manager()
         if not isinstance(state_manager, StateManagerRedis):
             raise RuntimeError(
-                f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
+                f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
                 "(All states should already be available -- this is likely a bug).",
             )
+        state_in_redis = await state_manager.get_state(
+            token=_substate_key(self.router.session.client_token, state_cls),
+            top_level=False,
+            for_state_instance=self,
+        )
 
-        # Find the missing parent states up to the common ancestor.
-        (
-            common_ancestor_name,
-            missing_parent_states,
-        ) = self._determine_missing_parent_states(target_state_cls)
-
-        # Fetch all missing parent states and link them up to the common ancestor.
-        parent_states_tuple = self._get_parent_states()
-        root_state = parent_states_tuple[-1][1]
-        parent_states_by_name = dict(parent_states_tuple)
-        parent_state = parent_states_by_name[common_ancestor_name]
-        for parent_state_name in missing_parent_states:
-            try:
-                parent_state = root_state.get_substate(parent_state_name.split("."))
-                # The requested state is already cached, do NOT fetch it again.
-                continue
-            except ValueError:
-                # The requested state is missing, fetch from redis.
-                pass
-            parent_state = await state_manager.get_state(
-                token=_substate_key(
-                    self.router.session.client_token, parent_state_name
-                ),
-                top_level=False,
-                get_substates=False,
-                parent_state=parent_state,
+        if not isinstance(state_in_redis, state_cls):
+            raise StateMismatchError(
+                f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
             )
 
-        # Return the direct parent of target_state_cls for subsequent linking.
-        return parent_state
+        return state_in_redis
 
     def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get a state instance from the cache.
@@ -1562,44 +1532,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             )
         return substate
 
-    async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
-        """Get a state instance from redis.
-
-        Args:
-            state_cls: The class of the state.
-
-        Returns:
-            The instance of state_cls associated with this state's client_token.
-
-        Raises:
-            RuntimeError: If redis is not used in this backend process.
-            StateMismatchError: If the state instance is not of the expected type.
-        """
-        # Fetch all missing parent states from redis.
-        parent_state_of_state_cls = await self._populate_parent_states(state_cls)
-
-        # Then get the target state and all its substates.
-        state_manager = get_state_manager()
-        if not isinstance(state_manager, StateManagerRedis):
-            raise RuntimeError(
-                f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
-                "(All states should already be available -- this is likely a bug).",
-            )
-
-        state_in_redis = await state_manager.get_state(
-            token=_substate_key(self.router.session.client_token, state_cls),
-            top_level=False,
-            get_substates=True,
-            parent_state=parent_state_of_state_cls,
-        )
-
-        if not isinstance(state_in_redis, state_cls):
-            raise StateMismatchError(
-                f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
-            )
-
-        return state_in_redis
-
     async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get an instance of the state associated with this token.
 
@@ -1738,7 +1670,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
         )
 
-    def _as_state_update(
+    async def _as_state_update(
         self,
         handler: EventHandler,
         events: EventSpec | list[EventSpec] | None,
@@ -1766,7 +1698,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         try:
             # Get the delta after processing the event.
-            delta = state.get_delta()
+            delta = await _resolve_delta(state.get_delta())
             state._clean()
 
             return StateUpdate(
@@ -1866,24 +1798,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             # Handle async generators.
             if inspect.isasyncgen(events):
                 async for event in events:
-                    yield state._as_state_update(handler, event, final=False)
-                yield state._as_state_update(handler, events=None, final=True)
+                    yield await state._as_state_update(handler, event, final=False)
+                yield await state._as_state_update(handler, events=None, final=True)
 
             # Handle regular generators.
             elif inspect.isgenerator(events):
                 try:
                     while True:
-                        yield state._as_state_update(handler, next(events), final=False)
+                        yield await state._as_state_update(
+                            handler, next(events), final=False
+                        )
                 except StopIteration as si:
                     # the "return" value of the generator is not available
                     # in the loop, we must catch StopIteration to access it
                     if si.value is not None:
-                        yield state._as_state_update(handler, si.value, final=False)
-                yield state._as_state_update(handler, events=None, final=True)
+                        yield await state._as_state_update(
+                            handler, si.value, final=False
+                        )
+                yield await state._as_state_update(handler, events=None, final=True)
 
             # Handle regular event chains.
             else:
-                yield state._as_state_update(handler, events, final=True)
+                yield await state._as_state_update(handler, events, final=True)
 
         # If an error occurs, throw a window alert.
         except Exception as ex:
@@ -1893,7 +1829,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
             )
 
-            yield state._as_state_update(
+            yield await state._as_state_update(
                 handler,
                 event_specs,
                 final=True,
@@ -1901,15 +1837,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
     def _mark_dirty_computed_vars(self) -> None:
         """Mark ComputedVars that need to be recalculated based on dirty_vars."""
+        # Append expired computed vars to dirty_vars to trigger recalculation
+        self.dirty_vars.update(self._expired_computed_vars())
+        # Append always dirty computed vars to dirty_vars to trigger recalculation
+        self.dirty_vars.update(self._always_dirty_computed_vars)
+
         dirty_vars = self.dirty_vars
         while dirty_vars:
             calc_vars, dirty_vars = dirty_vars, set()
-            for cvar in self._dirty_computed_vars(from_vars=calc_vars):
-                self.dirty_vars.add(cvar)
+            for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
+                if state_name == self.get_full_name():
+                    defining_state = self
+                else:
+                    defining_state = self._get_root_state().get_substate(
+                        tuple(state_name.split("."))
+                    )
+                defining_state.dirty_vars.add(cvar)
                 dirty_vars.add(cvar)
-                actual_var = self.computed_vars.get(cvar)
+                actual_var = defining_state.computed_vars.get(cvar)
                 if actual_var is not None:
-                    actual_var.mark_dirty(instance=self)
+                    actual_var.mark_dirty(instance=defining_state)
+                if defining_state is not self:
+                    defining_state._mark_dirty()
 
     def _expired_computed_vars(self) -> set[str]:
         """Determine ComputedVars that need to be recalculated based on the expiration time.
@@ -1925,7 +1874,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
     def _dirty_computed_vars(
         self, from_vars: set[str] | None = None, include_backend: bool = True
-    ) -> set[str]:
+    ) -> set[tuple[str, str]]:
         """Determine ComputedVars that need to be recalculated based on the given vars.
 
         Args:
@@ -1936,33 +1885,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             Set of computed vars to include in the delta.
         """
         return {
-            cvar
+            (state_name, cvar)
             for dirty_var in from_vars or self.dirty_vars
-            for cvar in self._computed_var_dependencies[dirty_var]
+            for state_name, cvar in self._var_dependencies.get(dirty_var, set())
             if include_backend or not self.computed_vars[cvar]._backend
         }
 
-    @classmethod
-    def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
-        """Determine substates which could be affected by dirty vars in this state.
-
-        Returns:
-            Set of State classes that may need to be fetched to recalc computed vars.
-        """
-        # _always_dirty_substates need to be fetched to recalc computed vars.
-        fetch_substates = {
-            cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
-            for substate_name in cls._always_dirty_substates
-        }
-        for dependent_substates in cls._substate_var_dependencies.values():
-            fetch_substates.update(
-                {
-                    cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
-                    for substate_name in dependent_substates
-                }
-            )
-        return fetch_substates
-
     def get_delta(self) -> Delta:
         """Get the delta for the state.
 
@@ -1971,21 +1899,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         delta = {}
 
-        # Apply dirty variables down into substates
-        self.dirty_vars.update(self._always_dirty_computed_vars)
-        self._mark_dirty()
-
+        self._mark_dirty_computed_vars()
         frontend_computed_vars: set[str] = {
             name for name, cv in self.computed_vars.items() if not cv._backend
         }
 
         # Return the dirty vars for this instance, any cached/dependent computed vars,
         # and always dirty computed vars (cache=False)
-        delta_vars = (
-            self.dirty_vars.intersection(self.base_vars)
-            .union(self.dirty_vars.intersection(frontend_computed_vars))
-            .union(self._dirty_computed_vars(include_backend=False))
-            .union(self._always_dirty_computed_vars)
+        delta_vars = self.dirty_vars.intersection(self.base_vars).union(
+            self.dirty_vars.intersection(frontend_computed_vars)
         )
 
         subdelta: Dict[str, Any] = {
@@ -2015,23 +1937,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             self.parent_state.dirty_substates.add(self.get_name())
             self.parent_state._mark_dirty()
 
-        # Append expired computed vars to dirty_vars to trigger recalculation
-        self.dirty_vars.update(self._expired_computed_vars())
-
         # have to mark computed vars dirty to allow access to newly computed
         # values within the same ComputedVar function
         self._mark_dirty_computed_vars()
-        self._mark_dirty_substates()
-
-    def _mark_dirty_substates(self):
-        """Propagate dirty var / computed var status into substates."""
-        substates = self.substates
-        for var in self.dirty_vars:
-            for substate_name in self._substate_var_dependencies[var]:
-                self.dirty_substates.add(substate_name)
-                substate = substates[substate_name]
-                substate.dirty_vars.add(var)
-                substate._mark_dirty()
 
     def _update_was_touched(self):
         """Update the _was_touched flag based on dirty_vars."""
@@ -2103,11 +2011,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             The object as a dictionary.
         """
         if include_computed:
-            # Apply dirty variables down into substates to allow never-cached ComputedVar to
-            # trigger recalculation of dependent vars
-            self.dirty_vars.update(self._always_dirty_computed_vars)
-            self._mark_dirty()
-
+            self._mark_dirty_computed_vars()
         base_vars = {
             prop_name: self.get_value(prop_name) for prop_name in self.base_vars
         }
@@ -2824,7 +2728,7 @@ class StateProxy(wrapt.ObjectProxy):
             await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
         )
 
-    def _as_state_update(self, *args, **kwargs) -> StateUpdate:
+    async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
         """Temporarily allow mutability to access parent_state.
 
         Args:
@@ -2837,7 +2741,7 @@ class StateProxy(wrapt.ObjectProxy):
         original_mutable = self._self_mutable
         self._self_mutable = True
         try:
-            return self.__wrapped__._as_state_update(*args, **kwargs)
+            return await self.__wrapped__._as_state_update(*args, **kwargs)
         finally:
             self._self_mutable = original_mutable
 
@@ -3313,103 +3217,106 @@ class StateManagerRedis(StateManager):
         b"evicted",
     }
 
-    async def _get_parent_state(
-        self, token: str, state: BaseState | None = None
-    ) -> BaseState | None:
-        """Get the parent state for the state requested in the token.
+    def _get_required_state_classes(
+        self,
+        target_state_cls: Type[BaseState],
+        subclasses: bool = False,
+        required_state_classes: set[Type[BaseState]] | None = None,
+    ) -> set[Type[BaseState]]:
+        """Recursively determine which states are required to fetch the target state.
+
+        This will always include potentially dirty substates that depend on vars
+        in the target_state_cls.
 
         Args:
-            token: The token to get the state for (_substate_key).
-            state: The state instance to get parent state for.
+            target_state_cls: The target state class being fetched.
+            subclasses: Whether to include subclasses of the target state.
+            required_state_classes: Recursive argument tracking state classes that have already been seen.
 
         Returns:
-            The parent state for the state requested by the token or None if there is no such parent.
-        """
-        parent_state = None
-        client_token, state_path = _split_substate_key(token)
-        parent_state_name = state_path.rpartition(".")[0]
-        if parent_state_name:
-            cached_substates = None
-            if state is not None:
-                cached_substates = [state]
-            # Retrieve the parent state to populate event handlers onto this substate.
-            parent_state = await self.get_state(
-                token=_substate_key(client_token, parent_state_name),
-                top_level=False,
-                get_substates=False,
-                cached_substates=cached_substates,
+            The set of state classes required to fetch the target state.
+        """
+        if required_state_classes is None:
+            required_state_classes = set()
+        # Get the substates if requested.
+        if subclasses:
+            for substate in target_state_cls.get_substates():
+                self._get_required_state_classes(
+                    substate,
+                    subclasses=True,
+                    required_state_classes=required_state_classes,
+                )
+        if target_state_cls in required_state_classes:
+            return required_state_classes
+        required_state_classes.add(target_state_cls)
+
+        # Get dependent substates.
+        for pd_substates in target_state_cls._get_potentially_dirty_states():
+            self._get_required_state_classes(
+                pd_substates,
+                subclasses=False,
+                required_state_classes=required_state_classes,
             )
-        return parent_state
 
-    async def _populate_substates(
-        self,
-        token: str,
-        state: BaseState,
-        all_substates: bool = False,
-    ):
-        """Fetch and link substates for the given state instance.
+        # Get the parent state if it exists.
+        if parent_state := target_state_cls.get_parent_state():
+            self._get_required_state_classes(
+                parent_state,
+                subclasses=False,
+                required_state_classes=required_state_classes,
+            )
+        return required_state_classes
 
-        There is no return value; the side-effect is that `state` will have `substates` populated,
-        and each substate will have its `parent_state` set to `state`.
+    def _get_populated_states(
+        self,
+        target_state: BaseState,
+        populated_states: dict[str, BaseState] | None = None,
+    ) -> dict[str, BaseState]:
+        """Recursively determine which states from target_state are already fetched.
 
         Args:
-            token: The token to get the state for.
-            state: The state instance to populate substates for.
-            all_substates: Whether to fetch all substates or just required substates.
-        """
-        client_token, _ = _split_substate_key(token)
+            target_state: The state to check for populated states.
+            populated_states: Recursive argument tracking states seen in previous calls.
 
-        if all_substates:
-            # All substates are requested.
-            fetch_substates = state.get_substates()
-        else:
-            # Only _potentially_dirty_substates need to be fetched to recalc computed vars.
-            fetch_substates = state._potentially_dirty_substates()
-
-        tasks = {}
-        # Retrieve the necessary substates from redis.
-        for substate_cls in fetch_substates:
-            if substate_cls.get_name() in state.substates:
-                continue
-            substate_name = substate_cls.get_name()
-            tasks[substate_name] = asyncio.create_task(
-                self.get_state(
-                    token=_substate_key(client_token, substate_cls),
-                    top_level=False,
-                    get_substates=all_substates,
-                    parent_state=state,
-                )
+        Returns:
+            A dictionary of state full name to state instance.
+        """
+        if populated_states is None:
+            populated_states = {}
+        if target_state.get_full_name() in populated_states:
+            return populated_states
+        populated_states[target_state.get_full_name()] = target_state
+        for substate in target_state.substates.values():
+            self._get_populated_states(substate, populated_states=populated_states)
+        if target_state.parent_state is not None:
+            self._get_populated_states(
+                target_state.parent_state, populated_states=populated_states
             )
-
-        for substate_name, substate_task in tasks.items():
-            state.substates[substate_name] = await substate_task
+        return populated_states
 
     @override
     async def get_state(
         self,
         token: str,
         top_level: bool = True,
-        get_substates: bool = True,
-        parent_state: BaseState | None = None,
-        cached_substates: list[BaseState] | None = None,
+        for_state_instance: BaseState | None = None,
     ) -> BaseState:
         """Get the state for a token.
 
         Args:
             token: The token to get the state for.
             top_level: If true, return an instance of the top-level state (self.state).
-            get_substates: If true, also retrieve substates.
-            parent_state: If provided, use this parent_state instead of getting it from redis.
-            cached_substates: If provided, attach these substates to the state.
+            for_state_instance: If provided, attach the requested states to this existing state tree.
 
         Returns:
             The state for the token.
 
         Raises:
-            RuntimeError: when the state_cls is not specified in the token
+            RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
+                requested state was not fetched.
         """
         # Split the actual token from the fully qualified substate name.
-        _, state_path = _split_substate_key(token)
+        token, state_path = _split_substate_key(token)
         if state_path:
             # Get the State class associated with the given path.
             state_cls = self.state.get_class_substate(state_path)
@@ -3418,43 +3325,59 @@ class StateManagerRedis(StateManager):
                 f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
             )
 
-        # The deserialized or newly created (sub)state instance.
-        state = None
-
-        # Fetch the serialized substate from redis.
-        redis_state = await self.redis.get(token)
-
-        if redis_state is not None:
-            # Deserialize the substate.
-            with contextlib.suppress(StateSchemaMismatchError):
-                state = BaseState._deserialize(data=redis_state)
-        if state is None:
-            # Key didn't exist or schema mismatch so create a new instance for this token.
-            state = state_cls(
-                init_substates=False,
-                _reflex_internal_init=True,
-            )
-        # Populate parent state if missing and requested.
-        if parent_state is None:
-            parent_state = await self._get_parent_state(token, state)
-        # Set up Bidirectional linkage between this state and its parent.
-        if parent_state is not None:
-            parent_state.substates[state.get_name()] = state
-            state.parent_state = parent_state
-        # Avoid fetching substates multiple times.
-        if cached_substates:
-            for substate in cached_substates:
-                state.substates[substate.get_name()] = substate
-                if substate.parent_state is None:
-                    substate.parent_state = state
-        # Populate substates if requested.
-        await self._populate_substates(token, state, all_substates=get_substates)
+        # Determine which states we already have.
+        flat_state_tree: dict[str, BaseState] = (
+            self._get_populated_states(for_state_instance) if for_state_instance else {}
+        )
+
+        # Determine which states from the tree need to be fetched.
+        required_state_classes = sorted(
+            self._get_required_state_classes(state_cls, subclasses=True)
+            - {type(s) for s in flat_state_tree.values()},
+            key=lambda x: x.get_full_name(),
+        )
+
+        redis_pipeline = self.redis.pipeline()
+        for state_cls in required_state_classes:
+            redis_pipeline.get(_substate_key(token, state_cls))
+
+        for state_cls, redis_state in zip(
+            required_state_classes,
+            await redis_pipeline.execute(),
+            strict=False,
+        ):
+            state = None
+
+            if redis_state is not None:
+                # Deserialize the substate.
+                with contextlib.suppress(StateSchemaMismatchError):
+                    state = BaseState._deserialize(data=redis_state)
+            if state is None:
+                # Key didn't exist or schema mismatch so create a new instance for this token.
+                state = state_cls(
+                    init_substates=False,
+                    _reflex_internal_init=True,
+                )
+            flat_state_tree[state.get_full_name()] = state
+            if state.get_parent_state() is not None:
+                parent_state_name, _dot, state_name = state.get_full_name().rpartition(
+                    "."
+                )
+                parent_state = flat_state_tree.get(parent_state_name)
+                if parent_state is None:
+                    raise RuntimeError(
+                        f"Parent state for {state.get_full_name()} was not found "
+                        "in the state tree, but should have already been fetched. "
+                        "This is a bug",
+                    )
+                parent_state.substates[state_name] = state
+                state.parent_state = parent_state
 
         # To retain compatibility with previous implementation, by default, we return
-        # the top-level state by chasing `parent_state` pointers up the tree.
+        # the top-level state which should always be fetched or already cached.
         if top_level:
-            return state._get_root_state()
-        return state
+            return flat_state_tree[self.state.get_full_name()]
+        return flat_state_tree[state_cls.get_full_name()]
 
     @override
     async def set_state(
@@ -4154,12 +4077,19 @@ def reload_state_module(
         state: Recursive argument for the state class to reload.
 
     """
+    # Clean out all potentially dirty states of reloaded modules.
+    for pd_state in tuple(state._potentially_dirty_states):
+        with contextlib.suppress(ValueError):
+            if (
+                state.get_root_state().get_class_substate(pd_state).__module__ == module
+                and module is not None
+            ):
+                state._potentially_dirty_states.remove(pd_state)
     for subclass in tuple(state.class_subclasses):
         reload_state_module(module=module, state=subclass)
         if subclass.__module__ == module and module is not None:
             state.class_subclasses.remove(subclass)
             state._always_dirty_substates.discard(subclass.get_name())
-            state._computed_var_dependencies = defaultdict(set)
-            state._substate_var_dependencies = defaultdict(set)
+            state._var_dependencies = {}
             state._init_var_dependency_dicts()
     state.get_class_substate.cache_clear()

+ 1 - 1
reflex/utils/exec.py

@@ -488,7 +488,7 @@ def output_system_info():
         dependencies.append(fnm_info)
 
     if system == "Linux":
-        import distro
+        import distro  # pyright: ignore[reportMissingImports]
 
         os_version = distro.name(pretty=True)
     else:

+ 238 - 109
reflex/vars/base.py

@@ -5,7 +5,6 @@ from __future__ import annotations
 import contextlib
 import dataclasses
 import datetime
-import dis
 import functools
 import inspect
 import json
@@ -20,6 +19,7 @@ from typing import (
     Any,
     Callable,
     ClassVar,
+    Coroutine,
     Dict,
     FrozenSet,
     Generic,
@@ -51,7 +51,6 @@ from reflex.utils.exceptions import (
     VarAttributeError,
     VarDependencyError,
     VarTypeError,
-    VarValueError,
 )
 from reflex.utils.format import format_state_name
 from reflex.utils.imports import (
@@ -1983,7 +1982,7 @@ class ComputedVar(Var[RETURN_TYPE]):
     _initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
 
     # Explicit var dependencies to track
-    _static_deps: set[str] = dataclasses.field(default_factory=set)
+    _static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict)
 
     # Whether var dependencies should be auto-determined
     _auto_deps: bool = dataclasses.field(default=True)
@@ -2053,21 +2052,34 @@ class ComputedVar(Var[RETURN_TYPE]):
 
         object.__setattr__(self, "_update_interval", interval)
 
-        if deps is None:
-            deps = []
-        else:
+        _static_deps = {}
+        if isinstance(deps, dict):
+            # Assume a dict is coming from _replace, so no special processing.
+            _static_deps = deps
+        elif deps is not None:
             for dep in deps:
                 if isinstance(dep, Var):
-                    continue
-                if isinstance(dep, str) and dep != "":
-                    continue
-                raise TypeError(
-                    "ComputedVar dependencies must be Var instances or var names (non-empty strings)."
-                )
+                    state_name = (
+                        all_var_data.state
+                        if (all_var_data := dep._get_all_var_data())
+                        and all_var_data.state
+                        else None
+                    )
+                    if all_var_data is not None:
+                        var_name = all_var_data.field_name
+                    else:
+                        var_name = dep._js_expr
+                    _static_deps.setdefault(state_name, set()).add(var_name)
+                elif isinstance(dep, str) and dep != "":
+                    _static_deps.setdefault(None, set()).add(dep)
+                else:
+                    raise TypeError(
+                        "ComputedVar dependencies must be Var instances or var names (non-empty strings)."
+                    )
         object.__setattr__(
             self,
             "_static_deps",
-            {dep._js_expr if isinstance(dep, Var) else dep for dep in deps},
+            _static_deps,
         )
         object.__setattr__(self, "_auto_deps", auto_deps)
 
@@ -2149,6 +2161,13 @@ class ComputedVar(Var[RETURN_TYPE]):
             return True
         return datetime.datetime.now() - last_updated > self._update_interval
 
+    @overload
+    def __get__(
+        self: ComputedVar[bool],
+        instance: None,
+        owner: Type,
+    ) -> BooleanVar: ...
+
     @overload
     def __get__(
         self: ComputedVar[int] | ComputedVar[float],
@@ -2233,125 +2252,67 @@ class ComputedVar(Var[RETURN_TYPE]):
                 setattr(instance, self._last_updated_attr, datetime.datetime.now())
             value = getattr(instance, self._cache_attr)
 
+        self._check_deprecated_return_type(instance, value)
+
+        return value
+
+    def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None:
         if not _isinstance(value, self._var_type):
             console.error(
                 f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
                 f" type '{self._var_type}', got '{type(value)}'."
             )
 
-        return value
-
     def _deps(
         self,
-        objclass: Type,
+        objclass: Type[BaseState],
         obj: FunctionType | CodeType | None = None,
-        self_name: Optional[str] = None,
-    ) -> set[str]:
+    ) -> dict[str, set[str]]:
         """Determine var dependencies of this ComputedVar.
 
-        Save references to attributes accessed on "self".  Recursively called
-        when the function makes a method call on "self" or define comprehensions
-        or nested functions that may reference "self".
+        Save references to attributes accessed on "self" or other fetched states.
+
+        Recursively called when the function makes a method call on "self" or
+        define comprehensions or nested functions that may reference "self".
 
         Args:
             objclass: the class obj this ComputedVar is attached to.
             obj: the object to disassemble (defaults to the fget function).
-            self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
 
         Returns:
-            A set of variable names accessed by the given obj.
-
-        Raises:
-            VarValueError: if the function references the get_state, parent_state, or substates attributes
-                (cannot track deps in a related state, only implicitly via parent state).
+            A dictionary mapping state names to the set of variable names
+            accessed by the given obj.
         """
+        from .dep_tracking import DependencyTracker
+
+        d = {}
+        if self._static_deps:
+            d.update(self._static_deps)
+            # None is a placeholder for the current state class.
+            if None in d:
+                d[objclass.get_full_name()] = d.pop(None)
+
         if not self._auto_deps:
-            return self._static_deps
-        d = self._static_deps.copy()
+            return d
+
         if obj is None:
             fget = self._fget
             if fget is not None:
                 obj = cast(FunctionType, fget)
             else:
-                return set()
-        with contextlib.suppress(AttributeError):
-            # unbox functools.partial
-            obj = cast(FunctionType, obj.func)  # pyright: ignore [reportAttributeAccessIssue]
-        with contextlib.suppress(AttributeError):
-            # unbox EventHandler
-            obj = cast(FunctionType, obj.fn)  # pyright: ignore [reportAttributeAccessIssue]
+                return d
 
-        if self_name is None and isinstance(obj, FunctionType):
-            try:
-                # the first argument to the function is the name of "self" arg
-                self_name = obj.__code__.co_varnames[0]
-            except (AttributeError, IndexError):
-                self_name = None
-        if self_name is None:
-            # cannot reference attributes on self if method takes no args
-            return set()
-
-        invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
-        self_is_top_of_stack = False
-        for instruction in dis.get_instructions(obj):
-            if (
-                instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
-                and instruction.argval == self_name
-            ):
-                # bytecode loaded the class instance to the top of stack, next load instruction
-                # is referencing an attribute on self
-                self_is_top_of_stack = True
-                continue
-            if self_is_top_of_stack and instruction.opname in (
-                "LOAD_ATTR",
-                "LOAD_METHOD",
-            ):
-                try:
-                    ref_obj = getattr(objclass, instruction.argval)
-                except Exception:
-                    ref_obj = None
-                if instruction.argval in invalid_names:
-                    raise VarValueError(
-                        f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
-                    )
-                if callable(ref_obj):
-                    # recurse into callable attributes
-                    d.update(
-                        self._deps(
-                            objclass=objclass,
-                            obj=ref_obj,  # pyright: ignore [reportArgumentType]
-                        )
-                    )
-                # recurse into property fget functions
-                elif isinstance(ref_obj, property) and not isinstance(
-                    ref_obj, ComputedVar
-                ):
-                    d.update(
-                        self._deps(
-                            objclass=objclass,
-                            obj=ref_obj.fget,  # pyright: ignore [reportArgumentType]
-                        )
-                    )
-                elif (
-                    instruction.argval in objclass.backend_vars
-                    or instruction.argval in objclass.vars
-                ):
-                    # var access
-                    d.add(instruction.argval)
-            elif instruction.opname == "LOAD_CONST" and isinstance(
-                instruction.argval, CodeType
-            ):
-                # recurse into nested functions / comprehensions, which can reference
-                # instance attributes from the outer scope
-                d.update(
-                    self._deps(
-                        objclass=objclass,
-                        obj=instruction.argval,
-                        self_name=self_name,
-                    )
-                )
-            self_is_top_of_stack = False
-        return d
+        try:
+            return DependencyTracker(
+                func=obj, state_cls=objclass, dependencies=d
+            ).dependencies
+        except Exception as e:
+            console.warn(
+                "Failed to automatically determine dependencies for computed var "
+                f"{objclass.__name__}.{self._js_expr}: {e}. "
+                "Provide static_deps and set auto_deps=False to suppress this warning."
+            )
+            return d
 
     def mark_dirty(self, instance: BaseState) -> None:
         """Mark this ComputedVar as dirty.
@@ -2362,6 +2323,37 @@ class ComputedVar(Var[RETURN_TYPE]):
         with contextlib.suppress(AttributeError):
             delattr(instance, self._cache_attr)
 
+    def add_dependency(self, objclass: Type[BaseState], dep: Var):
+        """Explicitly add a dependency to the ComputedVar.
+
+        After adding the dependency, when the `dep` changes, this computed var
+        will be marked dirty.
+
+        Args:
+            objclass: The class obj this ComputedVar is attached to.
+            dep: The dependency to add.
+
+        Raises:
+            VarDependencyError: If the dependency is not a Var instance with a
+                state and field name
+        """
+        if all_var_data := dep._get_all_var_data():
+            state_name = all_var_data.state
+            if state_name:
+                var_name = all_var_data.field_name
+                if var_name:
+                    self._static_deps.setdefault(state_name, set()).add(var_name)
+                    objclass.get_root_state().get_class_substate(
+                        state_name
+                    )._var_dependencies.setdefault(var_name, set()).add(
+                        (objclass.get_full_name(), self._js_expr)
+                    )
+                    return
+        raise VarDependencyError(
+            "ComputedVar dependencies must be Var instances with a state and "
+            f"field name, got {dep!r}."
+        )
+
     def _determine_var_type(self) -> Type:
         """Get the type of the var.
 
@@ -2398,6 +2390,126 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
     pass
 
 
+async def _default_async_computed_var(_self: BaseState) -> Any:
+    return None
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    init=False,
+    slots=True,
+)
+class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
+    """A computed var that wraps a coroutinefunction."""
+
+    _fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = (
+        dataclasses.field(default=_default_async_computed_var)
+    )
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[bool],
+        instance: None,
+        owner: Type,
+    ) -> BooleanVar: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[int] | ComputedVar[float],
+        instance: None,
+        owner: Type,
+    ) -> NumberVar: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[str],
+        instance: None,
+        owner: Type,
+    ) -> StringVar: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]],
+        instance: None,
+        owner: Type,
+    ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[list[LIST_INSIDE]],
+        instance: None,
+        owner: Type,
+    ) -> ArrayVar[list[LIST_INSIDE]]: ...
+
+    @overload
+    def __get__(
+        self: AsyncComputedVar[tuple[LIST_INSIDE, ...]],
+        instance: None,
+        owner: Type,
+    ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
+
+    @overload
+    def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ...
+
+    @overload
+    def __get__(
+        self, instance: BaseState, owner: Type
+    ) -> Coroutine[None, None, RETURN_TYPE]: ...
+
+    def __get__(
+        self, instance: BaseState | None, owner
+    ) -> Var | Coroutine[None, None, RETURN_TYPE]:
+        """Get the ComputedVar value.
+
+        If the value is already cached on the instance, return the cached value.
+
+        Args:
+            instance: the instance of the class accessing this computed var.
+            owner: the class that this descriptor is attached to.
+
+        Returns:
+            The value of the var for the given instance.
+        """
+        if instance is None:
+            return super(AsyncComputedVar, self).__get__(instance, owner)
+
+        if not self._cache:
+
+            async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
+                value = await self.fget(instance)
+                self._check_deprecated_return_type(instance, value)
+                return value
+
+            return _awaitable_result()
+        else:
+            # handle caching
+            async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
+                if not hasattr(instance, self._cache_attr) or self.needs_update(
+                    instance
+                ):
+                    # Set cache attr on state instance.
+                    setattr(instance, self._cache_attr, await self.fget(instance))
+                    # Ensure the computed var gets serialized to redis.
+                    instance._was_touched = True
+                    # Set the last updated timestamp on the state instance.
+                    setattr(instance, self._last_updated_attr, datetime.datetime.now())
+                value = getattr(instance, self._cache_attr)
+                self._check_deprecated_return_type(instance, value)
+                return value
+
+            return _awaitable_result()
+
+    @property
+    def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]:
+        """Get the getter function.
+
+        Returns:
+            The getter function.
+        """
+        return self._fget
+
+
 if TYPE_CHECKING:
     BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
 
@@ -2464,10 +2576,27 @@ def computed_var(
         raise VarDependencyError("Cannot track dependencies without caching.")
 
     if fget is not None:
-        return ComputedVar(fget, cache=cache)
+        if inspect.iscoroutinefunction(fget):
+            computed_var_cls = AsyncComputedVar
+        else:
+            computed_var_cls = ComputedVar
+        return computed_var_cls(
+            fget,
+            initial_value=initial_value,
+            cache=cache,
+            deps=deps,
+            auto_deps=auto_deps,
+            interval=interval,
+            backend=backend,
+            **kwargs,
+        )
 
     def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
-        return ComputedVar(
+        if inspect.iscoroutinefunction(fget):
+            computed_var_cls = AsyncComputedVar
+        else:
+            computed_var_cls = ComputedVar
+        return computed_var_cls(
             fget,
             initial_value=initial_value,
             cache=cache,

+ 344 - 0
reflex/vars/dep_tracking.py

@@ -0,0 +1,344 @@
+"""Collection of base classes."""
+
+from __future__ import annotations
+
+import contextlib
+import dataclasses
+import dis
+import enum
+import inspect
+from types import CellType, CodeType, FunctionType
+from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
+
+from reflex.utils.exceptions import VarValueError
+
+if TYPE_CHECKING:
+    from reflex.state import BaseState
+
+    from .base import Var
+
+
+CellEmpty = object()
+
+
+def get_cell_value(cell: CellType) -> Any:
+    """Get the value of a cell object.
+
+    Args:
+        cell: The cell object to get the value from. (func.__closure__ objects)
+
+    Returns:
+        The value from the cell or CellEmpty if a ValueError is raised.
+    """
+    try:
+        return cell.cell_contents
+    except ValueError:
+        return CellEmpty
+
+
+class ScanStatus(enum.Enum):
+    """State of the dis instruction scanning loop."""
+
+    SCANNING = enum.auto()
+    GETTING_ATTR = enum.auto()
+    GETTING_STATE = enum.auto()
+    GETTING_VAR = enum.auto()
+
+
+@dataclasses.dataclass
+class DependencyTracker:
+    """State machine for identifying state attributes that are accessed by a function."""
+
+    func: FunctionType | CodeType = dataclasses.field()
+    state_cls: Type[BaseState] = dataclasses.field()
+
+    dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
+
+    scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
+    top_of_stack: str | None = dataclasses.field(default=None)
+
+    tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)
+
+    _getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)
+    _getting_var_instructions: list[dis.Instruction] = dataclasses.field(
+        default_factory=list
+    )
+
+    INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
+
+    def __post_init__(self):
+        """After initializing, populate the dependencies dict."""
+        with contextlib.suppress(AttributeError):
+            # unbox functools.partial
+            self.func = cast(FunctionType, self.func.func)  # pyright: ignore[reportAttributeAccessIssue]
+        with contextlib.suppress(AttributeError):
+            # unbox EventHandler
+            self.func = cast(FunctionType, self.func.fn)  # pyright: ignore[reportAttributeAccessIssue]
+
+        if isinstance(self.func, FunctionType):
+            with contextlib.suppress(AttributeError, IndexError):
+                # the first argument to the function is the name of "self" arg
+                self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
+
+        self._populate_dependencies()
+
+    def _merge_deps(self, tracker: DependencyTracker) -> None:
+        """Merge dependencies from another DependencyTracker.
+
+        Args:
+            tracker: The DependencyTracker to merge dependencies from.
+        """
+        for state_name, dep_name in tracker.dependencies.items():
+            self.dependencies.setdefault(state_name, set()).update(dep_name)
+
+    def load_attr_or_method(self, instruction: dis.Instruction) -> None:
+        """Handle loading an attribute or method from the object on top of the stack.
+
+        This method directly tracks attributes and recursively merges
+        dependencies from analyzing the dependencies of any methods called.
+
+        Args:
+            instruction: The dis instruction to process.
+
+        Raises:
+            VarValueError: if the attribute is an disallowed name.
+        """
+        from .base import ComputedVar
+
+        if instruction.argval in self.INVALID_NAMES:
+            raise VarValueError(
+                f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
+            )
+        if instruction.argval == "get_state":
+            # Special case: arbitrary state access requested.
+            self.scan_status = ScanStatus.GETTING_STATE
+            return
+        if instruction.argval == "get_var_value":
+            # Special case: arbitrary var access requested.
+            self.scan_status = ScanStatus.GETTING_VAR
+            return
+
+        # Reset status back to SCANNING after attribute is accessed.
+        self.scan_status = ScanStatus.SCANNING
+        if not self.top_of_stack:
+            return
+        target_state = self.tracked_locals[self.top_of_stack]
+        try:
+            ref_obj = getattr(target_state, instruction.argval)
+        except AttributeError:
+            # Not found on this state class, maybe it is a dynamic attribute that will be picked up later.
+            ref_obj = None
+
+        if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
+            # recurse into property fget functions
+            ref_obj = ref_obj.fget
+        if callable(ref_obj):
+            # recurse into callable attributes
+            self._merge_deps(
+                type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
+            )
+        elif (
+            instruction.argval in target_state.backend_vars
+            or instruction.argval in target_state.vars
+        ):
+            # var access
+            self.dependencies.setdefault(target_state.get_full_name(), set()).add(
+                instruction.argval
+            )
+
+    def _get_globals(self) -> dict[str, Any]:
+        """Get the globals of the function.
+
+        Returns:
+            The var names and values in the globals of the function.
+        """
+        if isinstance(self.func, CodeType):
+            return {}
+        return self.func.__globals__  # pyright: ignore[reportAttributeAccessIssue]
+
+    def _get_closure(self) -> dict[str, Any]:
+        """Get the closure of the function, with unbound values omitted.
+
+        Returns:
+            The var names and values in the closure of the function.
+        """
+        if isinstance(self.func, CodeType):
+            return {}
+        return {
+            var_name: get_cell_value(cell)
+            for var_name, cell in zip(
+                self.func.__code__.co_freevars,  # pyright: ignore[reportAttributeAccessIssue]
+                self.func.__closure__ or (),
+                strict=False,
+            )
+            if get_cell_value(cell) is not CellEmpty
+        }
+
+    def handle_getting_state(self, instruction: dis.Instruction) -> None:
+        """Handle bytecode analysis when `get_state` was called in the function.
+
+        If the wrapped function is getting an arbitrary state and saving it to a
+        local variable, this method associates the local variable name with the
+        state class in self.tracked_locals.
+
+        When an attribute/method is accessed on a tracked local, it will be
+        associated with this state.
+
+        Args:
+            instruction: The dis instruction to process.
+
+        Raises:
+            VarValueError: if the state class cannot be determined from the instruction.
+        """
+        from reflex.state import BaseState
+
+        if instruction.opname == "LOAD_FAST":
+            raise VarValueError(
+                f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
+            )
+        if isinstance(self.func, CodeType):
+            raise VarValueError(
+                "Dependency detection cannot identify get_state class from a code object."
+            )
+        if instruction.opname == "LOAD_GLOBAL":
+            # Special case: referencing state class from global scope.
+            try:
+                self._getting_state_class = self._get_globals()[instruction.argval]
+            except (ValueError, KeyError) as ve:
+                raise VarValueError(
+                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
+                ) from ve
+        elif instruction.opname == "LOAD_DEREF":
+            # Special case: referencing state class from closure.
+            try:
+                self._getting_state_class = self._get_closure()[instruction.argval]
+            except (ValueError, KeyError) as ve:
+                raise VarValueError(
+                    f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
+                ) from ve
+        elif instruction.opname == "STORE_FAST":
+            # Storing the result of get_state in a local variable.
+            if not isinstance(self._getting_state_class, type) or not issubclass(
+                self._getting_state_class, BaseState
+            ):
+                raise VarValueError(
+                    f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
+                )
+            self.tracked_locals[instruction.argval] = self._getting_state_class
+            self.scan_status = ScanStatus.SCANNING
+            self._getting_state_class = None
+
+    def _eval_var(self) -> Var:
+        """Evaluate instructions from the wrapped function to get the Var object.
+
+        Returns:
+            The Var object.
+
+        Raises:
+            VarValueError: if the source code for the var cannot be determined.
+        """
+        # Get the original source code and eval it to get the Var.
+        module = inspect.getmodule(self.func)
+        positions0 = self._getting_var_instructions[0].positions
+        positions1 = self._getting_var_instructions[-1].positions
+        if module is None or positions0 is None or positions1 is None:
+            raise VarValueError(
+                f"Cannot determine the source code for the var in {self.func!r}."
+            )
+        start_line = positions0.lineno
+        start_column = positions0.col_offset
+        end_line = positions1.end_lineno
+        end_column = positions1.end_col_offset
+        if (
+            start_line is None
+            or start_column is None
+            or end_line is None
+            or end_column is None
+        ):
+            raise VarValueError(
+                f"Cannot determine the source code for the var in {self.func!r}."
+            )
+        source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line]
+        # Create a python source string snippet.
+        if len(source) > 1:
+            snipped_source = "".join(
+                [
+                    *source[0][start_column:],
+                    *(source[1:-2] if len(source) > 2 else []),
+                    *source[-1][: end_column - 1],
+                ]
+            )
+        else:
+            snipped_source = source[0][start_column : end_column - 1]
+        # Evaluate the string in the context of the function's globals and closure.
+        return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
+
+    def handle_getting_var(self, instruction: dis.Instruction) -> None:
+        """Handle bytecode analysis when `get_var_value` was called in the function.
+
+        This only really works if the expression passed to `get_var_value` is
+        evaluable in the function's global scope or closure, so getting the var
+        value from a var saved in a local variable or in the class instance is
+        not possible.
+
+        Args:
+            instruction: The dis instruction to process.
+
+        Raises:
+            VarValueError: if the source code for the var cannot be determined.
+        """
+        if instruction.opname == "CALL" and self._getting_var_instructions:
+            if self._getting_var_instructions:
+                the_var = self._eval_var()
+                the_var_data = the_var._get_all_var_data()
+                if the_var_data is None:
+                    raise VarValueError(
+                        f"Cannot determine the source code for the var in {self.func!r}."
+                    )
+                self.dependencies.setdefault(the_var_data.state, set()).add(
+                    the_var_data.field_name
+                )
+            self._getting_var_instructions.clear()
+            self.scan_status = ScanStatus.SCANNING
+        else:
+            self._getting_var_instructions.append(instruction)
+
+    def _populate_dependencies(self) -> None:
+        """Update self.dependencies based on the disassembly of self.func.
+
+        Save references to attributes accessed on "self" or other fetched states.
+
+        Recursively called when the function makes a method call on "self" or
+        define comprehensions or nested functions that may reference "self".
+        """
+        for instruction in dis.get_instructions(self.func):
+            if self.scan_status == ScanStatus.GETTING_STATE:
+                self.handle_getting_state(instruction)
+            elif self.scan_status == ScanStatus.GETTING_VAR:
+                self.handle_getting_var(instruction)
+            elif (
+                instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
+                and instruction.argval in self.tracked_locals
+            ):
+                # bytecode loaded the class instance to the top of stack, next load instruction
+                # is referencing an attribute on self
+                self.top_of_stack = instruction.argval
+                self.scan_status = ScanStatus.GETTING_ATTR
+            elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
+                "LOAD_ATTR",
+                "LOAD_METHOD",
+            ):
+                self.load_attr_or_method(instruction)
+                self.top_of_stack = None
+            elif instruction.opname == "LOAD_CONST" and isinstance(
+                instruction.argval, CodeType
+            ):
+                # recurse into nested functions / comprehensions, which can reference
+                # instance attributes from the outer scope
+                self._merge_deps(
+                    type(self)(
+                        func=instruction.argval,
+                        state_cls=self.state_cls,
+                        tracked_locals=self.tracked_locals,
+                    )
+                )

+ 7 - 5
tests/integration/tests_playwright/test_table.py

@@ -3,7 +3,7 @@
 from typing import Generator
 
 import pytest
-from playwright.sync_api import Page
+from playwright.sync_api import Page, expect
 
 from reflex.testing import AppHarness
 
@@ -87,12 +87,14 @@ def test_table(page: Page, table_app: AppHarness):
     table = page.get_by_role("table")
 
     # Check column headers
-    headers = table.get_by_role("columnheader").all_inner_texts()
-    assert headers == expected_col_headers
+    headers = table.get_by_role("columnheader")
+    for header, exp_value in zip(headers.all(), expected_col_headers, strict=True):
+        expect(header).to_have_text(exp_value)
 
     # Check rows headers
-    rows = table.get_by_role("rowheader").all_inner_texts()
-    assert rows == expected_row_headers
+    rows = table.get_by_role("rowheader")
+    for row, expected_row in zip(rows.all(), expected_row_headers, strict=True):
+        expect(row).to_have_text(expected_row)
 
     # Check cells
     rows = table.get_by_role("cell").all_inner_texts()

+ 14 - 4
tests/units/test_app.py

@@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
     assert app._pages.keys() == {"test/[dynamic]"}
     assert "dynamic" in app._state.computed_vars
     assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
-        constants.ROUTER
+        EmptyState.get_full_name(): {constants.ROUTER},
     }
-    assert constants.ROUTER in app._state()._computed_var_dependencies
+    assert constants.ROUTER in app._state()._var_dependencies
 
 
 def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@@ -995,9 +995,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     assert arg_name in app._state.vars
     assert arg_name in app._state.computed_vars
     assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
-        constants.ROUTER
+        DynamicState.get_full_name(): {constants.ROUTER},
     }
-    assert constants.ROUTER in app._state()._computed_var_dependencies
+    assert constants.ROUTER in app._state()._var_dependencies
 
     substate_token = _substate_key(token, DynamicState)
     sid = "mock_sid"
@@ -1555,6 +1555,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
         def bar(self) -> str:
             return "bar"
 
+    class Child1(ValidDepState):
+        @computed_var(deps=["base", ValidDepState.bar])
+        def other(self) -> str:
+            return "other"
+
+    class Child2(ValidDepState):
+        @computed_var(deps=["base", Child1.other])
+        def other(self) -> str:
+            return "other"
+
     app._state = ValidDepState
     app._compile()
 

+ 174 - 38
tests/units/test_state.py

@@ -14,6 +14,7 @@ from typing import (
     Any,
     AsyncGenerator,
     Callable,
+    ClassVar,
     Dict,
     List,
     Optional,
@@ -1169,13 +1170,17 @@ def test_conditional_computed_vars():
 
     ms = MainState()
     # Initially there are no dirty computed vars.
-    assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
-    assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
-    assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
+    assert ms._dirty_computed_vars(from_vars={"flag"}) == {
+        (MainState.get_full_name(), "rendered_var")
+    }
+    assert ms._dirty_computed_vars(from_vars={"t2"}) == {
+        (MainState.get_full_name(), "rendered_var")
+    }
+    assert ms._dirty_computed_vars(from_vars={"t1"}) == {
+        (MainState.get_full_name(), "rendered_var")
+    }
     assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
-        "flag",
-        "t1",
-        "t2",
+        MainState.get_full_name(): {"flag", "t1", "t2"}
     }
 
 
@@ -1370,7 +1375,10 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
         assert isinstance(HandlerState.handler, EventHandler)
 
     s = HandlerState()
-    assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
+    assert (
+        HandlerState.get_full_name(),
+        "cached_x_side_effect",
+    ) in s._var_dependencies["x"]
     assert s.cached_x_side_effect == 1
     assert s.x == 43
     s.handler()
@@ -1460,15 +1468,15 @@ def test_computed_var_dependencies():
             return [z in self._z for z in range(5)]
 
     cs = ComputedState()
-    assert cs._computed_var_dependencies["v"] == {
-        "comp_v",
-        "comp_v_backend",
-        "comp_v_via_property",
+    assert cs._var_dependencies["v"] == {
+        (ComputedState.get_full_name(), "comp_v"),
+        (ComputedState.get_full_name(), "comp_v_backend"),
+        (ComputedState.get_full_name(), "comp_v_via_property"),
     }
-    assert cs._computed_var_dependencies["w"] == {"comp_w"}
-    assert cs._computed_var_dependencies["x"] == {"comp_x"}
-    assert cs._computed_var_dependencies["y"] == {"comp_y"}
-    assert cs._computed_var_dependencies["_z"] == {"comp_z"}
+    assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")}
+    assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")}
+    assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")}
+    assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")}
 
 
 def test_backend_method():
@@ -3180,7 +3188,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
 RxState = State
 
 
-def test_potentially_dirty_substates():
+def test_potentially_dirty_states():
     """Test that potentially_dirty_substates returns the correct substates.
 
     Even if the name "State" is shadowed, it should still work correctly.
@@ -3196,13 +3204,19 @@ def test_potentially_dirty_substates():
         def bar(self) -> str:
             return ""
 
-    assert RxState._potentially_dirty_substates() == set()
-    assert State._potentially_dirty_substates() == set()
-    assert C1._potentially_dirty_substates() == set()
+    assert RxState._get_potentially_dirty_states() == set()
+    assert State._get_potentially_dirty_states() == set()
+    assert C1._get_potentially_dirty_states() == set()
+
 
+@pytest.mark.asyncio
+async def test_router_var_dep(state_manager: StateManager, token: str) -> None:
+    """Test that router var dependencies are correctly tracked.
 
-def test_router_var_dep() -> None:
-    """Test that router var dependencies are correctly tracked."""
+    Args:
+        state_manager: A state manager.
+        token: A token.
+    """
 
     class RouterVarParentState(State):
         """A parent state for testing router var dependency."""
@@ -3219,30 +3233,27 @@ def test_router_var_dep() -> None:
     foo = RouterVarDepState.computed_vars["foo"]
     State._init_var_dependency_dicts()
 
-    assert foo._deps(objclass=RouterVarDepState) == {"router"}
-    assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
-    assert RouterVarParentState._substate_var_dependencies == {
-        "router": {RouterVarDepState.get_name()}
-    }
-    assert RouterVarDepState._computed_var_dependencies == {
-        "router": {"foo"},
+    assert foo._deps(objclass=RouterVarDepState) == {
+        RouterVarDepState.get_full_name(): {"router"}
     }
+    assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[
+        "router"
+    ]
 
-    rx_state = State()
-    parent_state = RouterVarParentState()
-    state = RouterVarDepState()
-
-    # link states
-    rx_state.substates = {RouterVarParentState.get_name(): parent_state}
-    parent_state.parent_state = rx_state
-    state.parent_state = parent_state
-    parent_state.substates = {RouterVarDepState.get_name(): state}
+    # Get state from state manager.
+    state_manager.state = State
+    rx_state = await state_manager.get_state(_substate_key(token, State))
+    assert RouterVarParentState.get_name() in rx_state.substates
+    parent_state = rx_state.substates[RouterVarParentState.get_name()]
+    assert RouterVarDepState.get_name() in parent_state.substates
+    state = parent_state.substates[RouterVarDepState.get_name()]
 
     assert state.dirty_vars == set()
 
     # Reassign router var
     state.router = state.router
-    assert state.dirty_vars == {"foo", "router"}
+    assert rx_state.dirty_vars == {"router"}
+    assert state.dirty_vars == {"foo"}
     assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
 
 
@@ -3801,3 +3812,128 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str):
     # Generic Var with no state
     with pytest.raises(UnretrievableVarValueError):
         await state.get_var_value(rx.Var("undefined"))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
+    """A test where an async computed var depends on a var in another state.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+
+    class Parent(BaseState):
+        """A root state like rx.State."""
+
+        parent_var: int = 0
+
+    class Child2(Parent):
+        """An unconnected child state."""
+
+        pass
+
+    class Child3(Parent):
+        """A child state with a computed var causing it to be pre-fetched.
+
+        If child3_var gets set to a value, and `get_state` erroneously
+        re-fetches it from redis, the value will be lost.
+        """
+
+        child3_var: int = 0
+
+        @rx.var(cache=True)
+        def v(self) -> int:
+            return self.child3_var
+
+    class Child(Parent):
+        """A state simulating UpdateVarsInternalState."""
+
+        @rx.var(cache=True)
+        async def v(self) -> int:
+            p = await self.get_state(Parent)
+            child3 = await self.get_state(Child3)
+            return child3.child3_var + p.parent_var
+
+    mock_app.state_manager.state = mock_app._state = Parent
+
+    # Get the top level state via unconnected sibling.
+    root = await mock_app.state_manager.get_state(_substate_key(token, Child))
+    # Set value in parent_var to assert it does not get refetched later.
+    root.parent_var = 1
+
+    if isinstance(mock_app.state_manager, StateManagerRedis):
+        # When redis is used, only states with uncached computed vars are pre-fetched.
+        assert Child2.get_name() not in root.substates
+        assert Child3.get_name() not in root.substates
+
+    # Get the unconnected sibling state, which will be used to `get_state` other instances.
+    child = root.get_substate(Child.get_full_name().split("."))
+
+    # Get an uncached child state.
+    child2 = await child.get_state(Child2)
+    assert child2.parent_var == 1
+
+    # Set value on already-cached Child3 state (prefetched because it has a Computed Var).
+    child3 = await child.get_state(Child3)
+    child3.child3_var = 1
+
+    assert await child.v == 2
+    assert await child.v == 2
+    root.parent_var = 2
+    assert await child.v == 3
+
+
+class Table(rx.ComponentState):
+    """A table state."""
+
+    data: ClassVar[Var]
+
+    @rx.var(cache=True, auto_deps=False)
+    async def rows(self) -> List[Dict[str, Any]]:
+        """Computed var over the given rows.
+
+        Returns:
+            The data rows.
+        """
+        return await self.get_var_value(self.data)
+
+    @classmethod
+    def get_component(cls, data: Var) -> rx.Component:
+        """Get the component for the table.
+
+        Args:
+            data: The data var.
+
+        Returns:
+            The component.
+        """
+        cls.data = data
+        cls.computed_vars["rows"].add_dependency(cls, data)
+        return rx.foreach(data, lambda d: rx.text(d.to_string()))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
+    """A test where an async computed var depends on a var in another state.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+
+    class OtherState(rx.State):
+        """A state with a var."""
+
+        data: List[Dict[str, Any]] = [{"foo": "bar"}]
+
+    mock_app.state_manager.state = mock_app._state = rx.State
+    comp = Table.create(data=OtherState.data)
+    state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
+    other_state = await state.get_state(OtherState)
+    assert comp.State is not None
+    comp_state = await state.get_state(comp.State)
+    assert comp_state.dirty_vars == set()
+
+    other_state.data.append({"foo": "baz"})
+    assert "rows" in comp_state.dirty_vars

+ 25 - 3
tests/units/test_var.py

@@ -1807,9 +1807,9 @@ def cv_fget(state: BaseState) -> int:
 @pytest.mark.parametrize(
     "deps,expected",
     [
-        (["a"], {"a"}),
-        (["b"], {"b"}),
-        ([ComputedVar(fget=cv_fget)], {"cv_fget"}),
+        (["a"], {None: {"a"}}),
+        (["b"], {None: {"b"}}),
+        ([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}),
     ],
 )
 def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
@@ -1857,6 +1857,28 @@ def test_to_string_operation():
     assert single_var._var_type == Email
 
 
+@pytest.mark.asyncio
+async def test_async_computed_var():
+    side_effect_counter = 0
+
+    class AsyncComputedVarState(BaseState):
+        v: int = 1
+
+        @computed_var(cache=True)
+        async def async_computed_var(self) -> int:
+            nonlocal side_effect_counter
+            side_effect_counter += 1
+            return self.v + 1
+
+    my_state = AsyncComputedVarState()
+    assert await my_state.async_computed_var == 2
+    assert await my_state.async_computed_var == 2
+    my_state.v = 2
+    assert await my_state.async_computed_var == 3
+    assert await my_state.async_computed_var == 3
+    assert side_effect_counter == 2
+
+
 def test_var_data_hooks():
     var_data_str = VarData(hooks="what")
     var_data_list = VarData(hooks=["what"])