|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
import functools
|
|
|
+import inspect
|
|
|
import traceback
|
|
|
from abc import ABC
|
|
|
from typing import (
|
|
@@ -14,6 +15,7 @@ from typing import (
|
|
|
Optional,
|
|
|
Sequence,
|
|
|
Set,
|
|
|
+ Tuple,
|
|
|
Type,
|
|
|
Union,
|
|
|
)
|
|
@@ -51,6 +53,9 @@ class State(Base, ABC):
|
|
|
# Backend vars inherited
|
|
|
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
|
|
|
|
|
+ # Mapping of var name to set of computed variables that depend on it
|
|
|
+ computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
|
|
+
|
|
|
# The event handlers.
|
|
|
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
|
|
|
|
@@ -171,6 +176,7 @@ class State(Base, ABC):
|
|
|
**cls.base_vars,
|
|
|
**cls.computed_vars,
|
|
|
}
|
|
|
+ cls.computed_var_dependencies = {}
|
|
|
|
|
|
# Setup the base vars at the class level.
|
|
|
for prop in cls.base_vars.values():
|
|
@@ -472,12 +478,28 @@ class State(Base, ABC):
|
|
|
|
|
|
If the var is inherited, get the var from the parent state.
|
|
|
|
|
|
+ If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
|
|
|
+
|
|
|
Args:
|
|
|
name: The name of the var.
|
|
|
|
|
|
Returns:
|
|
|
The value of the var.
|
|
|
"""
|
|
|
+ vars = {
|
|
|
+ **super().__getattribute__("vars"),
|
|
|
+ **super().__getattribute__("backend_vars"),
|
|
|
+ }
|
|
|
+ if name in vars:
|
|
|
+ parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
|
|
|
+ if parent_frame is not None:
|
|
|
+ computed_vars = super().__getattribute__("computed_vars")
|
|
|
+ requesting_attribute_name = parent_frame_locals.get("name")
|
|
|
+ if requesting_attribute_name in computed_vars:
|
|
|
+ # Keep track of any ComputedVar that depends on this Var
|
|
|
+ super().__getattribute__("computed_var_dependencies").setdefault(
|
|
|
+ name, set()
|
|
|
+ ).add(requesting_attribute_name)
|
|
|
inherited_vars = {
|
|
|
**super().__getattribute__("inherited_vars"),
|
|
|
**super().__getattribute__("inherited_backend_vars"),
|
|
@@ -505,6 +527,7 @@ class State(Base, ABC):
|
|
|
|
|
|
if types.is_backend_variable(name):
|
|
|
self.backend_vars.__setitem__(name, value)
|
|
|
+ self.dirty_vars.add(name)
|
|
|
self.mark_dirty()
|
|
|
return
|
|
|
|
|
@@ -622,6 +645,28 @@ class State(Base, ABC):
|
|
|
# Return the state update.
|
|
|
return StateUpdate(delta=delta, events=events)
|
|
|
|
|
|
+ def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
|
|
+ """Get ComputedVars that need to be recomputed based on dirty_vars.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Set of computed vars to include in the delta.
|
|
|
+ """
|
|
|
+ dirty_computed_vars = set(
|
|
|
+ cvar
|
|
|
+ for dirty_var in from_vars or self.dirty_vars
|
|
|
+ for cvar in self.computed_vars
|
|
|
+ if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
|
|
+ )
|
|
|
+ if dirty_computed_vars:
|
|
|
+ # recursive call to catch computed vars that depend on computed vars
|
|
|
+ return dirty_computed_vars | self._dirty_computed_vars(
|
|
|
+ from_vars=dirty_computed_vars
|
|
|
+ )
|
|
|
+ return dirty_computed_vars
|
|
|
+
|
|
|
def get_delta(self) -> Delta:
|
|
|
"""Get the delta for the state.
|
|
|
|
|
@@ -630,10 +675,11 @@ class State(Base, ABC):
|
|
|
"""
|
|
|
delta = {}
|
|
|
|
|
|
- # Return the dirty vars, as well as all computed vars.
|
|
|
+ # Return the dirty vars, as well as computed vars depending on dirty vars.
|
|
|
subdelta = {
|
|
|
prop: getattr(self, prop)
|
|
|
- for prop in self.dirty_vars | self.computed_vars.keys()
|
|
|
+ for prop in self.dirty_vars | self._dirty_computed_vars()
|
|
|
+ if not types.is_backend_variable(prop)
|
|
|
}
|
|
|
if len(subdelta) > 0:
|
|
|
delta[self.get_full_name()] = subdelta
|
|
@@ -803,3 +849,24 @@ def _convert_mutable_datatypes(
|
|
|
field_value, reassign_field=reassign_field, field_name=field_name
|
|
|
)
|
|
|
return field_value
|
|
|
+
|
|
|
+
|
|
|
+def _get_previous_recursive_frame_info() -> (
|
|
|
+ Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
|
|
|
+):
|
|
|
+ """Find the previous frame of the same function that calls this helper.
|
|
|
+
|
|
|
+ For example, if this function is called from `State.__getattribute__`
|
|
|
+ (parent frame), then the returned frame will be the next earliest call
|
|
|
+ of the same function.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tuple of (frame_info, local_vars)
|
|
|
+
|
|
|
+ If no previous recursive frame is found up the stack, the frame info will be None.
|
|
|
+ """
|
|
|
+ _this_frame, parent_frame, *prev_frames = inspect.stack()
|
|
|
+ for frame in prev_frames:
|
|
|
+ if frame.frame.f_code == parent_frame.frame.f_code:
|
|
|
+ return frame, frame.frame.f_locals
|
|
|
+ return None, {}
|