|
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
import asyncio
|
|
import functools
|
|
import functools
|
|
-import inspect
|
|
|
|
import traceback
|
|
import traceback
|
|
from abc import ABC
|
|
from abc import ABC
|
|
|
|
+from collections import defaultdict
|
|
from typing import (
|
|
from typing import (
|
|
Any,
|
|
Any,
|
|
Callable,
|
|
Callable,
|
|
@@ -15,7 +15,6 @@ from typing import (
|
|
Optional,
|
|
Optional,
|
|
Sequence,
|
|
Sequence,
|
|
Set,
|
|
Set,
|
|
- Tuple,
|
|
|
|
Type,
|
|
Type,
|
|
Union,
|
|
Union,
|
|
)
|
|
)
|
|
@@ -54,9 +53,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
# Backend vars inherited
|
|
# Backend vars inherited
|
|
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
|
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
|
|
|
|
|
- # Mapping of var name to set of computed variables that depend on it
|
|
|
|
- computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
|
|
|
-
|
|
|
|
# The event handlers.
|
|
# The event handlers.
|
|
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
|
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
|
|
|
|
|
@@ -75,18 +71,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
# The routing path that triggered the state
|
|
# The routing path that triggered the state
|
|
router_data: Dict[str, Any] = {}
|
|
router_data: Dict[str, Any] = {}
|
|
|
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
|
|
|
+ # Mapping of var name to set of computed variables that depend on it
|
|
|
|
+ computed_var_dependencies: Dict[str, Set[str]] = {}
|
|
|
|
+
|
|
|
|
+ # Whether to track accessed vars.
|
|
|
|
+ track_vars: bool = False
|
|
|
|
+
|
|
|
|
+ # The current set of accessed vars during tracking.
|
|
|
|
+ tracked_vars: Set[str] = set()
|
|
|
|
+
|
|
|
|
+ def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
|
|
"""Initialize the state.
|
|
"""Initialize the state.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
*args: The args to pass to the Pydantic init method.
|
|
*args: The args to pass to the Pydantic init method.
|
|
|
|
+ parent_state: The parent state.
|
|
**kwargs: The kwargs to pass to the Pydantic init method.
|
|
**kwargs: The kwargs to pass to the Pydantic init method.
|
|
"""
|
|
"""
|
|
|
|
+ kwargs["parent_state"] = parent_state
|
|
super().__init__(*args, **kwargs)
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
# Setup the substates.
|
|
# Setup the substates.
|
|
for substate in self.get_substates():
|
|
for substate in self.get_substates():
|
|
- self.substates[substate.get_name()] = substate().set(parent_state=self)
|
|
|
|
|
|
+ self.substates[substate.get_name()] = substate(parent_state=self)
|
|
|
|
|
|
# Convert the event handlers to functions.
|
|
# Convert the event handlers to functions.
|
|
for name, event_handler in self.event_handlers.items():
|
|
for name, event_handler in self.event_handlers.items():
|
|
@@ -95,6 +102,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
# Initialize the mutable fields.
|
|
# Initialize the mutable fields.
|
|
self._init_mutable_fields()
|
|
self._init_mutable_fields()
|
|
|
|
|
|
|
|
+ # Initialize computed vars dependencies.
|
|
|
|
+ self.computed_var_dependencies = defaultdict(set)
|
|
|
|
+ for cvar in self.computed_vars:
|
|
|
|
+ self.tracked_vars = set()
|
|
|
|
+
|
|
|
|
+ # Enable tracking and get the computed var.
|
|
|
|
+ self.track_vars = True
|
|
|
|
+ self.__getattribute__(cvar)
|
|
|
|
+ self.track_vars = False
|
|
|
|
+
|
|
|
|
+ # Add the dependencies.
|
|
|
|
+ for var in self.tracked_vars:
|
|
|
|
+ self.computed_var_dependencies[var].add(cvar)
|
|
|
|
+
|
|
def _init_mutable_fields(self):
|
|
def _init_mutable_fields(self):
|
|
"""Initialize mutable fields.
|
|
"""Initialize mutable fields.
|
|
|
|
|
|
@@ -160,17 +181,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
|
|
cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
|
|
|
|
|
|
# Set the base and computed vars.
|
|
# Set the base and computed vars.
|
|
- skip_vars = set(cls.inherited_vars) | {
|
|
|
|
- "parent_state",
|
|
|
|
- "substates",
|
|
|
|
- "dirty_vars",
|
|
|
|
- "dirty_substates",
|
|
|
|
- "router_data",
|
|
|
|
- }
|
|
|
|
cls.base_vars = {
|
|
cls.base_vars = {
|
|
f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
|
|
f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
|
|
for f in cls.get_fields().values()
|
|
for f in cls.get_fields().values()
|
|
- if f.name not in skip_vars
|
|
|
|
|
|
+ if f.name not in cls.get_skip_vars()
|
|
}
|
|
}
|
|
cls.computed_vars = {
|
|
cls.computed_vars = {
|
|
v.name: v.set_state(cls)
|
|
v.name: v.set_state(cls)
|
|
@@ -202,6 +216,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
cls.event_handlers[name] = handler
|
|
cls.event_handlers[name] = handler
|
|
setattr(cls, name, handler)
|
|
setattr(cls, name, handler)
|
|
|
|
|
|
|
|
+ @classmethod
|
|
|
|
+ def get_skip_vars(cls) -> Set[str]:
|
|
|
|
+ """Get the vars to skip when serializing.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The vars to skip when serializing.
|
|
|
|
+ """
|
|
|
|
+ return set(cls.inherited_vars) | {
|
|
|
|
+ "parent_state",
|
|
|
|
+ "substates",
|
|
|
|
+ "dirty_vars",
|
|
|
|
+ "dirty_substates",
|
|
|
|
+ "router_data",
|
|
|
|
+ "computed_var_dependencies",
|
|
|
|
+ "track_vars",
|
|
|
|
+ "tracked_vars",
|
|
|
|
+ }
|
|
|
|
+
|
|
@classmethod
|
|
@classmethod
|
|
@functools.lru_cache()
|
|
@functools.lru_cache()
|
|
def get_parent_state(cls) -> Optional[Type[State]]:
|
|
def get_parent_state(cls) -> Optional[Type[State]]:
|
|
@@ -481,20 +513,21 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
Returns:
|
|
Returns:
|
|
The value of the var.
|
|
The value of the var.
|
|
"""
|
|
"""
|
|
- vars = {
|
|
|
|
- **super().__getattribute__("vars"),
|
|
|
|
- **super().__getattribute__("backend_vars"),
|
|
|
|
- }
|
|
|
|
- if name in vars:
|
|
|
|
- parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
|
|
|
|
- if parent_frame is not None:
|
|
|
|
- computed_vars = super().__getattribute__("computed_vars")
|
|
|
|
- requesting_attribute_name = parent_frame_locals.get("name")
|
|
|
|
- if requesting_attribute_name in computed_vars:
|
|
|
|
- # Keep track of any ComputedVar that depends on this Var
|
|
|
|
- super().__getattribute__("computed_var_dependencies").setdefault(
|
|
|
|
- name, set()
|
|
|
|
- ).add(requesting_attribute_name)
|
|
|
|
|
|
+ # If the state hasn't been initialized yet, return the default value.
|
|
|
|
+ if not super().__getattribute__("__dict__"):
|
|
|
|
+ return super().__getattribute__(name)
|
|
|
|
+
|
|
|
|
+ # Check if tracking is enabled.
|
|
|
|
+ if super().__getattribute__("track_vars"):
|
|
|
|
+ # Get the non-computed vars.
|
|
|
|
+ all_vars = {
|
|
|
|
+ **super().__getattribute__("vars"),
|
|
|
|
+ **super().__getattribute__("backend_vars"),
|
|
|
|
+ }
|
|
|
|
+ # Add the var to the tracked vars.
|
|
|
|
+ if name in all_vars:
|
|
|
|
+ super().__getattribute__("tracked_vars").add(name)
|
|
|
|
+
|
|
inherited_vars = {
|
|
inherited_vars = {
|
|
**super().__getattribute__("inherited_vars"),
|
|
**super().__getattribute__("inherited_vars"),
|
|
**super().__getattribute__("inherited_backend_vars"),
|
|
**super().__getattribute__("inherited_backend_vars"),
|
|
@@ -649,18 +682,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|
Returns:
|
|
Returns:
|
|
Set of computed vars to include in the delta.
|
|
Set of computed vars to include in the delta.
|
|
"""
|
|
"""
|
|
- dirty_computed_vars = set(
|
|
|
|
|
|
+ return set(
|
|
cvar
|
|
cvar
|
|
for dirty_var in from_vars or self.dirty_vars
|
|
for dirty_var in from_vars or self.dirty_vars
|
|
for cvar in self.computed_vars
|
|
for cvar in self.computed_vars
|
|
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
|
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
|
)
|
|
)
|
|
- if dirty_computed_vars:
|
|
|
|
- # recursive call to catch computed vars that depend on computed vars
|
|
|
|
- return dirty_computed_vars | self._dirty_computed_vars(
|
|
|
|
- from_vars=dirty_computed_vars
|
|
|
|
- )
|
|
|
|
- return dirty_computed_vars
|
|
|
|
|
|
|
|
def get_delta(self) -> Delta:
|
|
def get_delta(self) -> Delta:
|
|
"""Get the delta for the state.
|
|
"""Get the delta for the state.
|
|
@@ -844,24 +871,3 @@ def _convert_mutable_datatypes(
|
|
field_value, reassign_field=reassign_field, field_name=field_name
|
|
field_value, reassign_field=reassign_field, field_name=field_name
|
|
)
|
|
)
|
|
return field_value
|
|
return field_value
|
|
-
|
|
|
|
-
|
|
|
|
-def _get_previous_recursive_frame_info() -> (
|
|
|
|
- Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
|
|
|
|
-):
|
|
|
|
- """Find the previous frame of the same function that calls this helper.
|
|
|
|
-
|
|
|
|
- For example, if this function is called from `State.__getattribute__`
|
|
|
|
- (parent frame), then the returned frame will be the next earliest call
|
|
|
|
- of the same function.
|
|
|
|
-
|
|
|
|
- Returns:
|
|
|
|
- Tuple of (frame_info, local_vars)
|
|
|
|
-
|
|
|
|
- If no previous recursive frame is found up the stack, the frame info will be None.
|
|
|
|
- """
|
|
|
|
- _this_frame, parent_frame, *prev_frames = inspect.stack()
|
|
|
|
- for frame in prev_frames:
|
|
|
|
- if frame.frame.f_code == parent_frame.frame.f_code:
|
|
|
|
- return frame, frame.frame.f_locals
|
|
|
|
- return None, {}
|
|
|