|
@@ -74,12 +74,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# Mapping of var name to set of computed variables that depend on it
|
|
|
computed_var_dependencies: Dict[str, Set[str]] = {}
|
|
|
|
|
|
- # Whether to track accessed vars.
|
|
|
- track_vars: bool = False
|
|
|
-
|
|
|
- # The current set of accessed vars during tracking.
|
|
|
- tracked_vars: Set[str] = set()
|
|
|
-
|
|
|
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
|
|
|
"""Initialize the state.
|
|
|
|
|
@@ -102,22 +96,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
|
|
setattr(self, name, fn)
|
|
|
|
|
|
- # Initialize the mutable fields.
|
|
|
- self._init_mutable_fields()
|
|
|
-
|
|
|
# Initialize computed vars dependencies.
|
|
|
self.computed_var_dependencies = defaultdict(set)
|
|
|
- for cvar in self.computed_vars:
|
|
|
- self.tracked_vars = set()
|
|
|
-
|
|
|
- # Enable tracking and get the computed var.
|
|
|
- self.track_vars = True
|
|
|
- self.__getattribute__(cvar)
|
|
|
- self.track_vars = False
|
|
|
-
|
|
|
+ for cvar_name, cvar in self.computed_vars.items():
|
|
|
# Add the dependencies.
|
|
|
- for var in self.tracked_vars:
|
|
|
- self.computed_var_dependencies[var].add(cvar)
|
|
|
+ for var in cvar.deps():
|
|
|
+ self.computed_var_dependencies[var].add(cvar_name)
|
|
|
+
|
|
|
+ # Initialize the mutable fields.
|
|
|
+ self._init_mutable_fields()
|
|
|
|
|
|
def _init_mutable_fields(self):
|
|
|
"""Initialize mutable fields.
|
|
@@ -199,7 +186,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
**cls.base_vars,
|
|
|
**cls.computed_vars,
|
|
|
}
|
|
|
- cls.computed_var_dependencies = {}
|
|
|
cls.event_handlers = {}
|
|
|
|
|
|
# Setup the base vars at the class level.
|
|
@@ -233,8 +219,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
"dirty_substates",
|
|
|
"router_data",
|
|
|
"computed_var_dependencies",
|
|
|
- "track_vars",
|
|
|
- "tracked_vars",
|
|
|
}
|
|
|
|
|
|
@classmethod
|
|
@@ -508,8 +492,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
|
|
|
If the var is inherited, get the var from the parent state.
|
|
|
|
|
|
- If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
|
|
|
-
|
|
|
Args:
|
|
|
name: The name of the var.
|
|
|
|
|
@@ -520,17 +502,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
if not super().__getattribute__("__dict__"):
|
|
|
return super().__getattribute__(name)
|
|
|
|
|
|
- # Check if tracking is enabled.
|
|
|
- if super().__getattribute__("track_vars"):
|
|
|
- # Get the non-computed vars.
|
|
|
- all_vars = {
|
|
|
- **super().__getattribute__("vars"),
|
|
|
- **super().__getattribute__("backend_vars"),
|
|
|
- }
|
|
|
- # Add the var to the tracked vars.
|
|
|
- if name in all_vars:
|
|
|
- super().__getattribute__("tracked_vars").add(name)
|
|
|
-
|
|
|
inherited_vars = {
|
|
|
**super().__getattribute__("inherited_vars"),
|
|
|
**super().__getattribute__("inherited_backend_vars"),
|
|
@@ -676,55 +647,58 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# Return the state update.
|
|
|
return StateUpdate(delta=delta, events=events)
|
|
|
|
|
|
- def _dirty_computed_vars(
|
|
|
- self, from_vars: Optional[Set[str]] = None, check: bool = False
|
|
|
- ) -> Set[str]:
|
|
|
- """Get ComputedVars that need to be recomputed based on dirty_vars.
|
|
|
+ def _mark_dirty_computed_vars(self) -> None:
|
|
|
+ """Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
|
|
+ dirty_vars = self.dirty_vars
|
|
|
+ while dirty_vars:
|
|
|
+ calc_vars, dirty_vars = dirty_vars, set()
|
|
|
+ for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
|
|
+ self.dirty_vars.add(cvar)
|
|
|
+ dirty_vars.add(cvar)
|
|
|
+ actual_var = self.computed_vars.get(cvar)
|
|
|
+ if actual_var:
|
|
|
+ actual_var.mark_dirty(instance=self)
|
|
|
+
|
|
|
+ def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
|
|
+ """Determine ComputedVars that need to be recalculated based on the given vars.
|
|
|
|
|
|
Args:
|
|
|
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
|
|
|
- check: Whether to perform the check.
|
|
|
|
|
|
Returns:
|
|
|
Set of computed vars to include in the delta.
|
|
|
"""
|
|
|
- # If checking is disabled, return all computed vars.
|
|
|
- if not check:
|
|
|
- return set(self.computed_vars)
|
|
|
-
|
|
|
- # Return only the computed vars that depend on the dirty vars.
|
|
|
return set(
|
|
|
cvar
|
|
|
for dirty_var in from_vars or self.dirty_vars
|
|
|
- for cvar in self.computed_vars
|
|
|
- if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
|
|
+ for cvar in self.computed_var_dependencies[dirty_var]
|
|
|
)
|
|
|
|
|
|
- def get_delta(self, check: bool = False) -> Delta:
|
|
|
+ def get_delta(self) -> Delta:
|
|
|
"""Get the delta for the state.
|
|
|
|
|
|
- Args:
|
|
|
- check: Whether to check for dirty computed vars.
|
|
|
-
|
|
|
Returns:
|
|
|
The delta for the state.
|
|
|
"""
|
|
|
delta = {}
|
|
|
|
|
|
- # Return the dirty vars, as well as computed vars depending on dirty vars.
|
|
|
+ # Recursively find the substate deltas.
|
|
|
+ substates = self.substates
|
|
|
+ for substate in self.dirty_substates:
|
|
|
+ delta.update(substates[substate].get_delta())
|
|
|
+
|
|
|
+ # Return the dirty vars and dependent computed vars
|
|
|
+ delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
|
|
+ self._dirty_computed_vars()
|
|
|
+ )
|
|
|
subdelta = {
|
|
|
prop: getattr(self, prop)
|
|
|
- for prop in self.dirty_vars | self._dirty_computed_vars(check=check)
|
|
|
+ for prop in delta_vars
|
|
|
if not types.is_backend_variable(prop)
|
|
|
}
|
|
|
if len(subdelta) > 0:
|
|
|
delta[self.get_full_name()] = subdelta
|
|
|
|
|
|
- # Recursively find the substate deltas.
|
|
|
- substates = self.substates
|
|
|
- for substate in self.dirty_substates:
|
|
|
- delta.update(substates[substate].get_delta())
|
|
|
-
|
|
|
# Format the delta.
|
|
|
delta = format.format_state(delta)
|
|
|
|
|
@@ -737,6 +711,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
self.parent_state.dirty_substates.add(self.get_name())
|
|
|
self.parent_state.mark_dirty()
|
|
|
|
|
|
+ # have to mark computed vars dirty to allow access to newly computed
|
|
|
+ # values within the same ComputedVar function
|
|
|
+ self._mark_dirty_computed_vars()
|
|
|
+
|
|
|
def clean(self):
|
|
|
"""Reset the dirty vars."""
|
|
|
# Recursively clean the substates.
|