|
@@ -15,7 +15,6 @@ import time
|
|
|
import typing
|
|
|
import uuid
|
|
|
from abc import ABC, abstractmethod
|
|
|
-from collections import defaultdict
|
|
|
from hashlib import md5
|
|
|
from pathlib import Path
|
|
|
from types import FunctionType, MethodType
|
|
@@ -329,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
|
|
|
)
|
|
|
|
|
|
|
|
|
+async def _resolve_delta(delta: Delta) -> Delta:
|
|
|
+ """Await all coroutines in the delta.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ delta: The delta to process.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The same delta dict with all coroutines resolved to their return value.
|
|
|
+ """
|
|
|
+ tasks = {}
|
|
|
+ for state_name, state_delta in delta.items():
|
|
|
+ for var_name, value in state_delta.items():
|
|
|
+ if asyncio.iscoroutine(value):
|
|
|
+ tasks[state_name, var_name] = asyncio.create_task(value)
|
|
|
+ for (state_name, var_name), task in tasks.items():
|
|
|
+ delta[state_name][var_name] = await task
|
|
|
+ return delta
|
|
|
+
|
|
|
+
|
|
|
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
"""The state of the app."""
|
|
|
|
|
@@ -356,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# A set of subclassses of this class.
|
|
|
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
|
|
|
|
|
|
- # Mapping of var name to set of computed variables that depend on it
|
|
|
- _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
|
|
-
|
|
|
- # Mapping of var name to set of substates that depend on it
|
|
|
- _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
|
|
+ # Mapping of var name to set of (state_full_name, var_name) that depend on it.
|
|
|
+ _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
|
|
|
|
|
|
# Set of vars which always need to be recomputed
|
|
|
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
|
|
@@ -368,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# Set of substates which always need to be recomputed
|
|
|
_always_dirty_substates: ClassVar[Set[str]] = set()
|
|
|
|
|
|
+ # Set of states which might need to be recomputed if vars in this state change.
|
|
|
+ _potentially_dirty_states: ClassVar[Set[str]] = set()
|
|
|
+
|
|
|
# The parent state.
|
|
|
parent_state: Optional[BaseState] = None
|
|
|
|
|
@@ -519,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
|
|
|
# Reset dirty substate tracking for this class.
|
|
|
cls._always_dirty_substates = set()
|
|
|
+ cls._potentially_dirty_states = set()
|
|
|
|
|
|
# Get the parent vars.
|
|
|
parent_state = cls.get_parent_state()
|
|
@@ -622,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
setattr(cls, name, handler)
|
|
|
|
|
|
# Initialize per-class var dependency tracking.
|
|
|
- cls._computed_var_dependencies = defaultdict(set)
|
|
|
- cls._substate_var_dependencies = defaultdict(set)
|
|
|
+ cls._var_dependencies = {}
|
|
|
cls._init_var_dependency_dicts()
|
|
|
|
|
|
@staticmethod
|
|
@@ -768,26 +786,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
Additional updates tracking dicts for vars and substates that always
|
|
|
need to be recomputed.
|
|
|
"""
|
|
|
- inherited_vars = set(cls.inherited_vars).union(
|
|
|
- set(cls.inherited_backend_vars),
|
|
|
- )
|
|
|
for cvar_name, cvar in cls.computed_vars.items():
|
|
|
- # Add the dependencies.
|
|
|
- for var in cvar._deps(objclass=cls):
|
|
|
- cls._computed_var_dependencies[var].add(cvar_name)
|
|
|
- if var in inherited_vars:
|
|
|
- # track that this substate depends on its parent for this var
|
|
|
- state_name = cls.get_name()
|
|
|
- parent_state = cls.get_parent_state()
|
|
|
- while parent_state is not None and var in {
|
|
|
- **parent_state.vars,
|
|
|
- **parent_state.backend_vars,
|
|
|
+ if not cvar._cache:
|
|
|
+ # Do not perform dep calculation when cache=False (these are always dirty).
|
|
|
+ continue
|
|
|
+ for state_name, dvar_set in cvar._deps(objclass=cls).items():
|
|
|
+ state_cls = cls.get_root_state().get_class_substate(state_name)
|
|
|
+ for dvar in dvar_set:
|
|
|
+ defining_state_cls = state_cls
|
|
|
+ while dvar in {
|
|
|
+ *defining_state_cls.inherited_vars,
|
|
|
+ *defining_state_cls.inherited_backend_vars,
|
|
|
}:
|
|
|
- parent_state._substate_var_dependencies[var].add(state_name)
|
|
|
- state_name, parent_state = (
|
|
|
- parent_state.get_name(),
|
|
|
- parent_state.get_parent_state(),
|
|
|
- )
|
|
|
+ parent_state = defining_state_cls.get_parent_state()
|
|
|
+ if parent_state is not None:
|
|
|
+ defining_state_cls = parent_state
|
|
|
+ defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
|
|
|
+ (cls.get_full_name(), cvar_name)
|
|
|
+ )
|
|
|
+ defining_state_cls._potentially_dirty_states.add(
|
|
|
+ cls.get_full_name()
|
|
|
+ )
|
|
|
|
|
|
# ComputedVar with cache=False always need to be recomputed
|
|
|
cls._always_dirty_computed_vars = {
|
|
@@ -902,6 +921,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
raise ValueError(f"Only one parent state is allowed {parent_states}.")
|
|
|
return parent_states[0] if len(parent_states) == 1 else None
|
|
|
|
|
|
+ @classmethod
|
|
|
+ @functools.lru_cache()
|
|
|
+ def get_root_state(cls) -> Type[BaseState]:
|
|
|
+ """Get the root state.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The root state.
|
|
|
+ """
|
|
|
+ parent_state = cls.get_parent_state()
|
|
|
+ return cls if parent_state is None else parent_state.get_root_state()
|
|
|
+
|
|
|
@classmethod
|
|
|
def get_substates(cls) -> set[Type[BaseState]]:
|
|
|
"""Get the substates of the state.
|
|
@@ -1351,7 +1381,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
super().__setattr__(name, value)
|
|
|
|
|
|
# Add the var to the dirty list.
|
|
|
- if name in self.vars or name in self._computed_var_dependencies:
|
|
|
+ if name in self.base_vars:
|
|
|
self.dirty_vars.add(name)
|
|
|
self._mark_dirty()
|
|
|
|
|
@@ -1422,64 +1452,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
return self.substates[path[0]].get_substate(path[1:])
|
|
|
|
|
|
@classmethod
|
|
|
- def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
|
|
|
- """Find the name of the nearest common ancestor shared by this and the other state.
|
|
|
-
|
|
|
- Args:
|
|
|
- other: The other state.
|
|
|
-
|
|
|
- Returns:
|
|
|
- Full name of the nearest common ancestor.
|
|
|
- """
|
|
|
- common_ancestor_parts = []
|
|
|
- for part1, part2 in zip(
|
|
|
- cls.get_full_name().split("."),
|
|
|
- other.get_full_name().split("."),
|
|
|
- strict=True,
|
|
|
- ):
|
|
|
- if part1 != part2:
|
|
|
- break
|
|
|
- common_ancestor_parts.append(part1)
|
|
|
- return ".".join(common_ancestor_parts)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def _determine_missing_parent_states(
|
|
|
- cls, target_state_cls: Type[BaseState]
|
|
|
- ) -> tuple[str, list[str]]:
|
|
|
- """Determine the missing parent states between the target_state_cls and common ancestor of this state.
|
|
|
-
|
|
|
- Args:
|
|
|
- target_state_cls: The class of the state to find missing parent states for.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The name of the common ancestor and the list of missing parent states.
|
|
|
- """
|
|
|
- common_ancestor_name = cls._get_common_ancestor(target_state_cls)
|
|
|
- common_ancestor_parts = common_ancestor_name.split(".")
|
|
|
- target_state_parts = tuple(target_state_cls.get_full_name().split("."))
|
|
|
- relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
|
|
|
-
|
|
|
- # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
|
|
|
- fetch_parent_states = [common_ancestor_name]
|
|
|
- for relative_parent_state_name in relative_target_state_parts:
|
|
|
- fetch_parent_states.append(
|
|
|
- ".".join((fetch_parent_states[-1], relative_parent_state_name))
|
|
|
- )
|
|
|
-
|
|
|
- return common_ancestor_name, fetch_parent_states[1:-1]
|
|
|
-
|
|
|
- def _get_parent_states(self) -> list[tuple[str, BaseState]]:
|
|
|
- """Get all parent state instances up to the root of the state tree.
|
|
|
+ def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
|
|
|
+ """Get substates which may have dirty vars due to dependencies.
|
|
|
|
|
|
Returns:
|
|
|
- A list of tuples containing the name and the instance of each parent state.
|
|
|
+ The set of potentially dirty substate classes.
|
|
|
"""
|
|
|
- parent_states_with_name = []
|
|
|
- parent_state = self
|
|
|
- while parent_state.parent_state is not None:
|
|
|
- parent_state = parent_state.parent_state
|
|
|
- parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
|
|
- return parent_states_with_name
|
|
|
+ return {
|
|
|
+ cls.get_class_substate(substate_name)
|
|
|
+ for substate_name in cls._always_dirty_substates
|
|
|
+ }.union(
|
|
|
+ {
|
|
|
+ cls.get_root_state().get_class_substate(substate_name)
|
|
|
+ for substate_name in cls._potentially_dirty_states
|
|
|
+ }
|
|
|
+ )
|
|
|
|
|
|
def _get_root_state(self) -> BaseState:
|
|
|
"""Get the root state of the state tree.
|
|
@@ -1492,55 +1479,38 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
parent_state = parent_state.parent_state
|
|
|
return parent_state
|
|
|
|
|
|
- async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
|
|
|
- """Populate substates in the tree between the target_state_cls and common ancestor of this state.
|
|
|
+ async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
|
+ """Get a state instance from redis.
|
|
|
|
|
|
Args:
|
|
|
- target_state_cls: The class of the state to populate parent states for.
|
|
|
+ state_cls: The class of the state.
|
|
|
|
|
|
Returns:
|
|
|
- The parent state instance of target_state_cls.
|
|
|
+ The instance of state_cls associated with this state's client_token.
|
|
|
|
|
|
Raises:
|
|
|
RuntimeError: If redis is not used in this backend process.
|
|
|
+ StateMismatchError: If the state instance is not of the expected type.
|
|
|
"""
|
|
|
+ # Then get the target state and all its substates.
|
|
|
state_manager = get_state_manager()
|
|
|
if not isinstance(state_manager, StateManagerRedis):
|
|
|
raise RuntimeError(
|
|
|
- f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
|
|
|
+ f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
|
|
"(All states should already be available -- this is likely a bug).",
|
|
|
)
|
|
|
+ state_in_redis = await state_manager.get_state(
|
|
|
+ token=_substate_key(self.router.session.client_token, state_cls),
|
|
|
+ top_level=False,
|
|
|
+ for_state_instance=self,
|
|
|
+ )
|
|
|
|
|
|
- # Find the missing parent states up to the common ancestor.
|
|
|
- (
|
|
|
- common_ancestor_name,
|
|
|
- missing_parent_states,
|
|
|
- ) = self._determine_missing_parent_states(target_state_cls)
|
|
|
-
|
|
|
- # Fetch all missing parent states and link them up to the common ancestor.
|
|
|
- parent_states_tuple = self._get_parent_states()
|
|
|
- root_state = parent_states_tuple[-1][1]
|
|
|
- parent_states_by_name = dict(parent_states_tuple)
|
|
|
- parent_state = parent_states_by_name[common_ancestor_name]
|
|
|
- for parent_state_name in missing_parent_states:
|
|
|
- try:
|
|
|
- parent_state = root_state.get_substate(parent_state_name.split("."))
|
|
|
- # The requested state is already cached, do NOT fetch it again.
|
|
|
- continue
|
|
|
- except ValueError:
|
|
|
- # The requested state is missing, fetch from redis.
|
|
|
- pass
|
|
|
- parent_state = await state_manager.get_state(
|
|
|
- token=_substate_key(
|
|
|
- self.router.session.client_token, parent_state_name
|
|
|
- ),
|
|
|
- top_level=False,
|
|
|
- get_substates=False,
|
|
|
- parent_state=parent_state,
|
|
|
+ if not isinstance(state_in_redis, state_cls):
|
|
|
+ raise StateMismatchError(
|
|
|
+ f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
|
|
)
|
|
|
|
|
|
- # Return the direct parent of target_state_cls for subsequent linking.
|
|
|
- return parent_state
|
|
|
+ return state_in_redis
|
|
|
|
|
|
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
|
"""Get a state instance from the cache.
|
|
@@ -1562,44 +1532,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
)
|
|
|
return substate
|
|
|
|
|
|
- async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
|
- """Get a state instance from redis.
|
|
|
-
|
|
|
- Args:
|
|
|
- state_cls: The class of the state.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The instance of state_cls associated with this state's client_token.
|
|
|
-
|
|
|
- Raises:
|
|
|
- RuntimeError: If redis is not used in this backend process.
|
|
|
- StateMismatchError: If the state instance is not of the expected type.
|
|
|
- """
|
|
|
- # Fetch all missing parent states from redis.
|
|
|
- parent_state_of_state_cls = await self._populate_parent_states(state_cls)
|
|
|
-
|
|
|
- # Then get the target state and all its substates.
|
|
|
- state_manager = get_state_manager()
|
|
|
- if not isinstance(state_manager, StateManagerRedis):
|
|
|
- raise RuntimeError(
|
|
|
- f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
|
|
- "(All states should already be available -- this is likely a bug).",
|
|
|
- )
|
|
|
-
|
|
|
- state_in_redis = await state_manager.get_state(
|
|
|
- token=_substate_key(self.router.session.client_token, state_cls),
|
|
|
- top_level=False,
|
|
|
- get_substates=True,
|
|
|
- parent_state=parent_state_of_state_cls,
|
|
|
- )
|
|
|
-
|
|
|
- if not isinstance(state_in_redis, state_cls):
|
|
|
- raise StateMismatchError(
|
|
|
- f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
|
|
- )
|
|
|
-
|
|
|
- return state_in_redis
|
|
|
-
|
|
|
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
|
"""Get an instance of the state associated with this token.
|
|
|
|
|
@@ -1738,7 +1670,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
|
|
)
|
|
|
|
|
|
- def _as_state_update(
|
|
|
+ async def _as_state_update(
|
|
|
self,
|
|
|
handler: EventHandler,
|
|
|
events: EventSpec | list[EventSpec] | None,
|
|
@@ -1766,7 +1698,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
|
|
|
try:
|
|
|
# Get the delta after processing the event.
|
|
|
- delta = state.get_delta()
|
|
|
+ delta = await _resolve_delta(state.get_delta())
|
|
|
state._clean()
|
|
|
|
|
|
return StateUpdate(
|
|
@@ -1866,24 +1798,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# Handle async generators.
|
|
|
if inspect.isasyncgen(events):
|
|
|
async for event in events:
|
|
|
- yield state._as_state_update(handler, event, final=False)
|
|
|
- yield state._as_state_update(handler, events=None, final=True)
|
|
|
+ yield await state._as_state_update(handler, event, final=False)
|
|
|
+ yield await state._as_state_update(handler, events=None, final=True)
|
|
|
|
|
|
# Handle regular generators.
|
|
|
elif inspect.isgenerator(events):
|
|
|
try:
|
|
|
while True:
|
|
|
- yield state._as_state_update(handler, next(events), final=False)
|
|
|
+ yield await state._as_state_update(
|
|
|
+ handler, next(events), final=False
|
|
|
+ )
|
|
|
except StopIteration as si:
|
|
|
# the "return" value of the generator is not available
|
|
|
# in the loop, we must catch StopIteration to access it
|
|
|
if si.value is not None:
|
|
|
- yield state._as_state_update(handler, si.value, final=False)
|
|
|
- yield state._as_state_update(handler, events=None, final=True)
|
|
|
+ yield await state._as_state_update(
|
|
|
+ handler, si.value, final=False
|
|
|
+ )
|
|
|
+ yield await state._as_state_update(handler, events=None, final=True)
|
|
|
|
|
|
# Handle regular event chains.
|
|
|
else:
|
|
|
- yield state._as_state_update(handler, events, final=True)
|
|
|
+ yield await state._as_state_update(handler, events, final=True)
|
|
|
|
|
|
# If an error occurs, throw a window alert.
|
|
|
except Exception as ex:
|
|
@@ -1893,7 +1829,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
|
|
|
)
|
|
|
|
|
|
- yield state._as_state_update(
|
|
|
+ yield await state._as_state_update(
|
|
|
handler,
|
|
|
event_specs,
|
|
|
final=True,
|
|
@@ -1901,15 +1837,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
|
|
|
def _mark_dirty_computed_vars(self) -> None:
|
|
|
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
|
|
+ # Append expired computed vars to dirty_vars to trigger recalculation
|
|
|
+ self.dirty_vars.update(self._expired_computed_vars())
|
|
|
+ # Append always dirty computed vars to dirty_vars to trigger recalculation
|
|
|
+ self.dirty_vars.update(self._always_dirty_computed_vars)
|
|
|
+
|
|
|
dirty_vars = self.dirty_vars
|
|
|
while dirty_vars:
|
|
|
calc_vars, dirty_vars = dirty_vars, set()
|
|
|
- for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
|
|
- self.dirty_vars.add(cvar)
|
|
|
+ for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
|
|
+ if state_name == self.get_full_name():
|
|
|
+ defining_state = self
|
|
|
+ else:
|
|
|
+ defining_state = self._get_root_state().get_substate(
|
|
|
+ tuple(state_name.split("."))
|
|
|
+ )
|
|
|
+ defining_state.dirty_vars.add(cvar)
|
|
|
dirty_vars.add(cvar)
|
|
|
- actual_var = self.computed_vars.get(cvar)
|
|
|
+ actual_var = defining_state.computed_vars.get(cvar)
|
|
|
if actual_var is not None:
|
|
|
- actual_var.mark_dirty(instance=self)
|
|
|
+ actual_var.mark_dirty(instance=defining_state)
|
|
|
+ if defining_state is not self:
|
|
|
+ defining_state._mark_dirty()
|
|
|
|
|
|
def _expired_computed_vars(self) -> set[str]:
|
|
|
"""Determine ComputedVars that need to be recalculated based on the expiration time.
|
|
@@ -1925,7 +1874,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
|
|
|
def _dirty_computed_vars(
|
|
|
self, from_vars: set[str] | None = None, include_backend: bool = True
|
|
|
- ) -> set[str]:
|
|
|
+ ) -> set[tuple[str, str]]:
|
|
|
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
|
|
|
|
|
Args:
|
|
@@ -1936,33 +1885,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
Set of computed vars to include in the delta.
|
|
|
"""
|
|
|
return {
|
|
|
- cvar
|
|
|
+ (state_name, cvar)
|
|
|
for dirty_var in from_vars or self.dirty_vars
|
|
|
- for cvar in self._computed_var_dependencies[dirty_var]
|
|
|
+ for state_name, cvar in self._var_dependencies.get(dirty_var, set())
|
|
|
if include_backend or not self.computed_vars[cvar]._backend
|
|
|
}
|
|
|
|
|
|
- @classmethod
|
|
|
- def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
|
|
- """Determine substates which could be affected by dirty vars in this state.
|
|
|
-
|
|
|
- Returns:
|
|
|
- Set of State classes that may need to be fetched to recalc computed vars.
|
|
|
- """
|
|
|
- # _always_dirty_substates need to be fetched to recalc computed vars.
|
|
|
- fetch_substates = {
|
|
|
- cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
|
|
- for substate_name in cls._always_dirty_substates
|
|
|
- }
|
|
|
- for dependent_substates in cls._substate_var_dependencies.values():
|
|
|
- fetch_substates.update(
|
|
|
- {
|
|
|
- cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
|
|
- for substate_name in dependent_substates
|
|
|
- }
|
|
|
- )
|
|
|
- return fetch_substates
|
|
|
-
|
|
|
def get_delta(self) -> Delta:
|
|
|
"""Get the delta for the state.
|
|
|
|
|
@@ -1971,21 +1899,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
"""
|
|
|
delta = {}
|
|
|
|
|
|
- # Apply dirty variables down into substates
|
|
|
- self.dirty_vars.update(self._always_dirty_computed_vars)
|
|
|
- self._mark_dirty()
|
|
|
-
|
|
|
+ self._mark_dirty_computed_vars()
|
|
|
frontend_computed_vars: set[str] = {
|
|
|
name for name, cv in self.computed_vars.items() if not cv._backend
|
|
|
}
|
|
|
|
|
|
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
|
|
# and always dirty computed vars (cache=False)
|
|
|
- delta_vars = (
|
|
|
- self.dirty_vars.intersection(self.base_vars)
|
|
|
- .union(self.dirty_vars.intersection(frontend_computed_vars))
|
|
|
- .union(self._dirty_computed_vars(include_backend=False))
|
|
|
- .union(self._always_dirty_computed_vars)
|
|
|
+ delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
|
|
+ self.dirty_vars.intersection(frontend_computed_vars)
|
|
|
)
|
|
|
|
|
|
subdelta: Dict[str, Any] = {
|
|
@@ -2015,23 +1937,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
self.parent_state.dirty_substates.add(self.get_name())
|
|
|
self.parent_state._mark_dirty()
|
|
|
|
|
|
- # Append expired computed vars to dirty_vars to trigger recalculation
|
|
|
- self.dirty_vars.update(self._expired_computed_vars())
|
|
|
-
|
|
|
# have to mark computed vars dirty to allow access to newly computed
|
|
|
# values within the same ComputedVar function
|
|
|
self._mark_dirty_computed_vars()
|
|
|
- self._mark_dirty_substates()
|
|
|
-
|
|
|
- def _mark_dirty_substates(self):
|
|
|
- """Propagate dirty var / computed var status into substates."""
|
|
|
- substates = self.substates
|
|
|
- for var in self.dirty_vars:
|
|
|
- for substate_name in self._substate_var_dependencies[var]:
|
|
|
- self.dirty_substates.add(substate_name)
|
|
|
- substate = substates[substate_name]
|
|
|
- substate.dirty_vars.add(var)
|
|
|
- substate._mark_dirty()
|
|
|
|
|
|
def _update_was_touched(self):
|
|
|
"""Update the _was_touched flag based on dirty_vars."""
|
|
@@ -2103,11 +2011,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
The object as a dictionary.
|
|
|
"""
|
|
|
if include_computed:
|
|
|
- # Apply dirty variables down into substates to allow never-cached ComputedVar to
|
|
|
- # trigger recalculation of dependent vars
|
|
|
- self.dirty_vars.update(self._always_dirty_computed_vars)
|
|
|
- self._mark_dirty()
|
|
|
-
|
|
|
+ self._mark_dirty_computed_vars()
|
|
|
base_vars = {
|
|
|
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
|
|
|
}
|
|
@@ -2824,7 +2728,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
|
|
|
)
|
|
|
|
|
|
- def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
|
|
+ async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
|
|
"""Temporarily allow mutability to access parent_state.
|
|
|
|
|
|
Args:
|
|
@@ -2837,7 +2741,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
original_mutable = self._self_mutable
|
|
|
self._self_mutable = True
|
|
|
try:
|
|
|
- return self.__wrapped__._as_state_update(*args, **kwargs)
|
|
|
+ return await self.__wrapped__._as_state_update(*args, **kwargs)
|
|
|
finally:
|
|
|
self._self_mutable = original_mutable
|
|
|
|
|
@@ -3313,103 +3217,106 @@ class StateManagerRedis(StateManager):
|
|
|
b"evicted",
|
|
|
}
|
|
|
|
|
|
- async def _get_parent_state(
|
|
|
- self, token: str, state: BaseState | None = None
|
|
|
- ) -> BaseState | None:
|
|
|
- """Get the parent state for the state requested in the token.
|
|
|
+ def _get_required_state_classes(
|
|
|
+ self,
|
|
|
+ target_state_cls: Type[BaseState],
|
|
|
+ subclasses: bool = False,
|
|
|
+ required_state_classes: set[Type[BaseState]] | None = None,
|
|
|
+ ) -> set[Type[BaseState]]:
|
|
|
+ """Recursively determine which states are required to fetch the target state.
|
|
|
+
|
|
|
+ This will always include potentially dirty substates that depend on vars
|
|
|
+ in the target_state_cls.
|
|
|
|
|
|
Args:
|
|
|
- token: The token to get the state for (_substate_key).
|
|
|
- state: The state instance to get parent state for.
|
|
|
+ target_state_cls: The target state class being fetched.
|
|
|
+ subclasses: Whether to include subclasses of the target state.
|
|
|
+ required_state_classes: Recursive argument tracking state classes that have already been seen.
|
|
|
|
|
|
Returns:
|
|
|
- The parent state for the state requested by the token or None if there is no such parent.
|
|
|
- """
|
|
|
- parent_state = None
|
|
|
- client_token, state_path = _split_substate_key(token)
|
|
|
- parent_state_name = state_path.rpartition(".")[0]
|
|
|
- if parent_state_name:
|
|
|
- cached_substates = None
|
|
|
- if state is not None:
|
|
|
- cached_substates = [state]
|
|
|
- # Retrieve the parent state to populate event handlers onto this substate.
|
|
|
- parent_state = await self.get_state(
|
|
|
- token=_substate_key(client_token, parent_state_name),
|
|
|
- top_level=False,
|
|
|
- get_substates=False,
|
|
|
- cached_substates=cached_substates,
|
|
|
+ The set of state classes required to fetch the target state.
|
|
|
+ """
|
|
|
+ if required_state_classes is None:
|
|
|
+ required_state_classes = set()
|
|
|
+ # Get the substates if requested.
|
|
|
+ if subclasses:
|
|
|
+ for substate in target_state_cls.get_substates():
|
|
|
+ self._get_required_state_classes(
|
|
|
+ substate,
|
|
|
+ subclasses=True,
|
|
|
+ required_state_classes=required_state_classes,
|
|
|
+ )
|
|
|
+ if target_state_cls in required_state_classes:
|
|
|
+ return required_state_classes
|
|
|
+ required_state_classes.add(target_state_cls)
|
|
|
+
|
|
|
+ # Get dependent substates.
|
|
|
+ for pd_substates in target_state_cls._get_potentially_dirty_states():
|
|
|
+ self._get_required_state_classes(
|
|
|
+ pd_substates,
|
|
|
+ subclasses=False,
|
|
|
+ required_state_classes=required_state_classes,
|
|
|
)
|
|
|
- return parent_state
|
|
|
|
|
|
- async def _populate_substates(
|
|
|
- self,
|
|
|
- token: str,
|
|
|
- state: BaseState,
|
|
|
- all_substates: bool = False,
|
|
|
- ):
|
|
|
- """Fetch and link substates for the given state instance.
|
|
|
+ # Get the parent state if it exists.
|
|
|
+ if parent_state := target_state_cls.get_parent_state():
|
|
|
+ self._get_required_state_classes(
|
|
|
+ parent_state,
|
|
|
+ subclasses=False,
|
|
|
+ required_state_classes=required_state_classes,
|
|
|
+ )
|
|
|
+ return required_state_classes
|
|
|
|
|
|
- There is no return value; the side-effect is that `state` will have `substates` populated,
|
|
|
- and each substate will have its `parent_state` set to `state`.
|
|
|
+ def _get_populated_states(
|
|
|
+ self,
|
|
|
+ target_state: BaseState,
|
|
|
+ populated_states: dict[str, BaseState] | None = None,
|
|
|
+ ) -> dict[str, BaseState]:
|
|
|
+ """Recursively determine which states from target_state are already fetched.
|
|
|
|
|
|
Args:
|
|
|
- token: The token to get the state for.
|
|
|
- state: The state instance to populate substates for.
|
|
|
- all_substates: Whether to fetch all substates or just required substates.
|
|
|
- """
|
|
|
- client_token, _ = _split_substate_key(token)
|
|
|
+ target_state: The state to check for populated states.
|
|
|
+ populated_states: Recursive argument tracking states seen in previous calls.
|
|
|
|
|
|
- if all_substates:
|
|
|
- # All substates are requested.
|
|
|
- fetch_substates = state.get_substates()
|
|
|
- else:
|
|
|
- # Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
|
|
- fetch_substates = state._potentially_dirty_substates()
|
|
|
-
|
|
|
- tasks = {}
|
|
|
- # Retrieve the necessary substates from redis.
|
|
|
- for substate_cls in fetch_substates:
|
|
|
- if substate_cls.get_name() in state.substates:
|
|
|
- continue
|
|
|
- substate_name = substate_cls.get_name()
|
|
|
- tasks[substate_name] = asyncio.create_task(
|
|
|
- self.get_state(
|
|
|
- token=_substate_key(client_token, substate_cls),
|
|
|
- top_level=False,
|
|
|
- get_substates=all_substates,
|
|
|
- parent_state=state,
|
|
|
- )
|
|
|
+ Returns:
|
|
|
+ A dictionary of state full name to state instance.
|
|
|
+ """
|
|
|
+ if populated_states is None:
|
|
|
+ populated_states = {}
|
|
|
+ if target_state.get_full_name() in populated_states:
|
|
|
+ return populated_states
|
|
|
+ populated_states[target_state.get_full_name()] = target_state
|
|
|
+ for substate in target_state.substates.values():
|
|
|
+ self._get_populated_states(substate, populated_states=populated_states)
|
|
|
+ if target_state.parent_state is not None:
|
|
|
+ self._get_populated_states(
|
|
|
+ target_state.parent_state, populated_states=populated_states
|
|
|
)
|
|
|
-
|
|
|
- for substate_name, substate_task in tasks.items():
|
|
|
- state.substates[substate_name] = await substate_task
|
|
|
+ return populated_states
|
|
|
|
|
|
@override
|
|
|
async def get_state(
|
|
|
self,
|
|
|
token: str,
|
|
|
top_level: bool = True,
|
|
|
- get_substates: bool = True,
|
|
|
- parent_state: BaseState | None = None,
|
|
|
- cached_substates: list[BaseState] | None = None,
|
|
|
+ for_state_instance: BaseState | None = None,
|
|
|
) -> BaseState:
|
|
|
"""Get the state for a token.
|
|
|
|
|
|
Args:
|
|
|
token: The token to get the state for.
|
|
|
top_level: If true, return an instance of the top-level state (self.state).
|
|
|
- get_substates: If true, also retrieve substates.
|
|
|
- parent_state: If provided, use this parent_state instead of getting it from redis.
|
|
|
- cached_substates: If provided, attach these substates to the state.
|
|
|
+ for_state_instance: If provided, attach the requested states to this existing state tree.
|
|
|
|
|
|
Returns:
|
|
|
The state for the token.
|
|
|
|
|
|
Raises:
|
|
|
- RuntimeError: when the state_cls is not specified in the token
|
|
|
+ RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
|
|
|
+ requested state was not fetched.
|
|
|
"""
|
|
|
# Split the actual token from the fully qualified substate name.
|
|
|
- _, state_path = _split_substate_key(token)
|
|
|
+ token, state_path = _split_substate_key(token)
|
|
|
if state_path:
|
|
|
# Get the State class associated with the given path.
|
|
|
state_cls = self.state.get_class_substate(state_path)
|
|
@@ -3418,43 +3325,59 @@ class StateManagerRedis(StateManager):
|
|
|
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
|
|
)
|
|
|
|
|
|
- # The deserialized or newly created (sub)state instance.
|
|
|
- state = None
|
|
|
-
|
|
|
- # Fetch the serialized substate from redis.
|
|
|
- redis_state = await self.redis.get(token)
|
|
|
-
|
|
|
- if redis_state is not None:
|
|
|
- # Deserialize the substate.
|
|
|
- with contextlib.suppress(StateSchemaMismatchError):
|
|
|
- state = BaseState._deserialize(data=redis_state)
|
|
|
- if state is None:
|
|
|
- # Key didn't exist or schema mismatch so create a new instance for this token.
|
|
|
- state = state_cls(
|
|
|
- init_substates=False,
|
|
|
- _reflex_internal_init=True,
|
|
|
- )
|
|
|
- # Populate parent state if missing and requested.
|
|
|
- if parent_state is None:
|
|
|
- parent_state = await self._get_parent_state(token, state)
|
|
|
- # Set up Bidirectional linkage between this state and its parent.
|
|
|
- if parent_state is not None:
|
|
|
- parent_state.substates[state.get_name()] = state
|
|
|
- state.parent_state = parent_state
|
|
|
- # Avoid fetching substates multiple times.
|
|
|
- if cached_substates:
|
|
|
- for substate in cached_substates:
|
|
|
- state.substates[substate.get_name()] = substate
|
|
|
- if substate.parent_state is None:
|
|
|
- substate.parent_state = state
|
|
|
- # Populate substates if requested.
|
|
|
- await self._populate_substates(token, state, all_substates=get_substates)
|
|
|
+ # Determine which states we already have.
|
|
|
+ flat_state_tree: dict[str, BaseState] = (
|
|
|
+ self._get_populated_states(for_state_instance) if for_state_instance else {}
|
|
|
+ )
|
|
|
+
|
|
|
+ # Determine which states from the tree need to be fetched.
|
|
|
+ required_state_classes = sorted(
|
|
|
+ self._get_required_state_classes(state_cls, subclasses=True)
|
|
|
+ - {type(s) for s in flat_state_tree.values()},
|
|
|
+ key=lambda x: x.get_full_name(),
|
|
|
+ )
|
|
|
+
|
|
|
+ redis_pipeline = self.redis.pipeline()
|
|
|
+ for state_cls in required_state_classes:
|
|
|
+ redis_pipeline.get(_substate_key(token, state_cls))
|
|
|
+
|
|
|
+ for state_cls, redis_state in zip(
|
|
|
+ required_state_classes,
|
|
|
+ await redis_pipeline.execute(),
|
|
|
+ strict=False,
|
|
|
+ ):
|
|
|
+ state = None
|
|
|
+
|
|
|
+ if redis_state is not None:
|
|
|
+ # Deserialize the substate.
|
|
|
+ with contextlib.suppress(StateSchemaMismatchError):
|
|
|
+ state = BaseState._deserialize(data=redis_state)
|
|
|
+ if state is None:
|
|
|
+ # Key didn't exist or schema mismatch so create a new instance for this token.
|
|
|
+ state = state_cls(
|
|
|
+ init_substates=False,
|
|
|
+ _reflex_internal_init=True,
|
|
|
+ )
|
|
|
+ flat_state_tree[state.get_full_name()] = state
|
|
|
+ if state.get_parent_state() is not None:
|
|
|
+ parent_state_name, _dot, state_name = state.get_full_name().rpartition(
|
|
|
+ "."
|
|
|
+ )
|
|
|
+ parent_state = flat_state_tree.get(parent_state_name)
|
|
|
+ if parent_state is None:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Parent state for {state.get_full_name()} was not found "
|
|
|
+ "in the state tree, but should have already been fetched. "
|
|
|
+ "This is a bug",
|
|
|
+ )
|
|
|
+ parent_state.substates[state_name] = state
|
|
|
+ state.parent_state = parent_state
|
|
|
|
|
|
# To retain compatibility with previous implementation, by default, we return
|
|
|
- # the top-level state by chasing `parent_state` pointers up the tree.
|
|
|
+ # the top-level state which should always be fetched or already cached.
|
|
|
if top_level:
|
|
|
- return state._get_root_state()
|
|
|
- return state
|
|
|
+ return flat_state_tree[self.state.get_full_name()]
|
|
|
+ return flat_state_tree[state_cls.get_full_name()]
|
|
|
|
|
|
@override
|
|
|
async def set_state(
|
|
@@ -4154,12 +4077,19 @@ def reload_state_module(
|
|
|
state: Recursive argument for the state class to reload.
|
|
|
|
|
|
"""
|
|
|
+ # Clean out all potentially dirty states of reloaded modules.
|
|
|
+ for pd_state in tuple(state._potentially_dirty_states):
|
|
|
+ with contextlib.suppress(ValueError):
|
|
|
+ if (
|
|
|
+ state.get_root_state().get_class_substate(pd_state).__module__ == module
|
|
|
+ and module is not None
|
|
|
+ ):
|
|
|
+ state._potentially_dirty_states.remove(pd_state)
|
|
|
for subclass in tuple(state.class_subclasses):
|
|
|
reload_state_module(module=module, state=subclass)
|
|
|
if subclass.__module__ == module and module is not None:
|
|
|
state.class_subclasses.remove(subclass)
|
|
|
state._always_dirty_substates.discard(subclass.get_name())
|
|
|
- state._computed_var_dependencies = defaultdict(set)
|
|
|
- state._substate_var_dependencies = defaultdict(set)
|
|
|
+ state._var_dependencies = {}
|
|
|
state._init_var_dependency_dicts()
|
|
|
state.get_class_substate.cache_clear()
|