Browse Source

Speed up computed var dependency tracking (#864)

Nikhil Rao 2 years ago
parent
commit
f019e0e55a
3 changed files with 81 additions and 65 deletions
  1. 1 1
      pynecone/.templates/web/pynecone.json
  2. 63 57
      pynecone/state.py
  3. 17 7
      tests/test_state.py

+ 1 - 1
pynecone/.templates/web/pynecone.json

@@ -1,3 +1,3 @@
 {
-    "version": "0.1.21"
+    "version": "0.1.25"
 }

+ 63 - 57
pynecone/state.py

@@ -3,9 +3,9 @@ from __future__ import annotations
 
 import asyncio
 import functools
-import inspect
 import traceback
 from abc import ABC
+from collections import defaultdict
 from typing import (
     Any,
     Callable,
@@ -15,7 +15,6 @@ from typing import (
     Optional,
     Sequence,
     Set,
-    Tuple,
     Type,
     Union,
 )
@@ -54,9 +53,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # Backend vars inherited
     inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
 
-    # Mapping of var name to set of computed variables that depend on it
-    computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
-
     # The event handlers.
     event_handlers: ClassVar[Dict[str, EventHandler]] = {}
 
@@ -75,18 +71,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # The routing path that triggered the state
     router_data: Dict[str, Any] = {}
 
-    def __init__(self, *args, **kwargs):
+    # Mapping of var name to set of computed variables that depend on it
+    computed_var_dependencies: Dict[str, Set[str]] = {}
+
+    # Whether to track accessed vars.
+    track_vars: bool = False
+
+    # The current set of accessed vars during tracking.
+    tracked_vars: Set[str] = set()
+
+    def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
         """Initialize the state.
 
         Args:
             *args: The args to pass to the Pydantic init method.
+            parent_state: The parent state.
             **kwargs: The kwargs to pass to the Pydantic init method.
         """
+        kwargs["parent_state"] = parent_state
         super().__init__(*args, **kwargs)
 
         # Setup the substates.
         for substate in self.get_substates():
-            self.substates[substate.get_name()] = substate().set(parent_state=self)
+            self.substates[substate.get_name()] = substate(parent_state=self)
 
         # Convert the event handlers to functions.
         for name, event_handler in self.event_handlers.items():
@@ -95,6 +102,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         # Initialize the mutable fields.
         self._init_mutable_fields()
 
+        # Initialize computed vars dependencies.
+        self.computed_var_dependencies = defaultdict(set)
+        for cvar in self.computed_vars:
+            self.tracked_vars = set()
+
+            # Enable tracking and get the computed var.
+            self.track_vars = True
+            self.__getattribute__(cvar)
+            self.track_vars = False
+
+            # Add the dependencies.
+            for var in self.tracked_vars:
+                self.computed_var_dependencies[var].add(cvar)
+
     def _init_mutable_fields(self):
         """Initialize mutable fields.
 
@@ -160,17 +181,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
 
         # Set the base and computed vars.
-        skip_vars = set(cls.inherited_vars) | {
-            "parent_state",
-            "substates",
-            "dirty_vars",
-            "dirty_substates",
-            "router_data",
-        }
         cls.base_vars = {
             f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
             for f in cls.get_fields().values()
-            if f.name not in skip_vars
+            if f.name not in cls.get_skip_vars()
         }
         cls.computed_vars = {
             v.name: v.set_state(cls)
@@ -202,6 +216,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             cls.event_handlers[name] = handler
             setattr(cls, name, handler)
 
+    @classmethod
+    def get_skip_vars(cls) -> Set[str]:
+        """Get the vars to skip when serializing.
+
+        Returns:
+            The vars to skip when serializing.
+        """
+        return set(cls.inherited_vars) | {
+            "parent_state",
+            "substates",
+            "dirty_vars",
+            "dirty_substates",
+            "router_data",
+            "computed_var_dependencies",
+            "track_vars",
+            "tracked_vars",
+        }
+
     @classmethod
     @functools.lru_cache()
     def get_parent_state(cls) -> Optional[Type[State]]:
@@ -481,20 +513,21 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The value of the var.
         """
-        vars = {
-            **super().__getattribute__("vars"),
-            **super().__getattribute__("backend_vars"),
-        }
-        if name in vars:
-            parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
-            if parent_frame is not None:
-                computed_vars = super().__getattribute__("computed_vars")
-                requesting_attribute_name = parent_frame_locals.get("name")
-                if requesting_attribute_name in computed_vars:
-                    # Keep track of any ComputedVar that depends on this Var
-                    super().__getattribute__("computed_var_dependencies").setdefault(
-                        name, set()
-                    ).add(requesting_attribute_name)
+        # If the state hasn't been initialized yet, return the default value.
+        if not super().__getattribute__("__dict__"):
+            return super().__getattribute__(name)
+
+        # Check if tracking is enabled.
+        if super().__getattribute__("track_vars"):
+            # Get the non-computed vars.
+            all_vars = {
+                **super().__getattribute__("vars"),
+                **super().__getattribute__("backend_vars"),
+            }
+            # Add the var to the tracked vars.
+            if name in all_vars:
+                super().__getattribute__("tracked_vars").add(name)
+
         inherited_vars = {
             **super().__getattribute__("inherited_vars"),
             **super().__getattribute__("inherited_backend_vars"),
@@ -649,18 +682,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             Set of computed vars to include in the delta.
         """
-        dirty_computed_vars = set(
+        return set(
             cvar
             for dirty_var in from_vars or self.dirty_vars
             for cvar in self.computed_vars
             if cvar in self.computed_var_dependencies.get(dirty_var, set())
         )
-        if dirty_computed_vars:
-            # recursive call to catch computed vars that depend on computed vars
-            return dirty_computed_vars | self._dirty_computed_vars(
-                from_vars=dirty_computed_vars
-            )
-        return dirty_computed_vars
 
     def get_delta(self) -> Delta:
         """Get the delta for the state.
@@ -844,24 +871,3 @@ def _convert_mutable_datatypes(
             field_value, reassign_field=reassign_field, field_name=field_name
         )
     return field_value
-
-
-def _get_previous_recursive_frame_info() -> (
-    Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
-):
-    """Find the previous frame of the same function that calls this helper.
-
-    For example, if this function is called from `State.__getattribute__`
-    (parent frame), then the returned frame will be the next earliest call
-    of the same function.
-
-    Returns:
-        Tuple of (frame_info, local_vars)
-
-    If no previous recursive frame is found up the stack, the frame info will be None.
-    """
-    _this_frame, parent_frame, *prev_frames = inspect.stack()
-    for frame in prev_frames:
-        if frame.frame.f_code == parent_frame.frame.f_code:
-            return frame, frame.frame.f_locals
-    return None, {}

+ 17 - 7
tests/test_state.py

@@ -155,13 +155,7 @@ def test_base_class_vars(test_state):
     cls = type(test_state)
 
     for field in fields:
-        if field in (
-            "parent_state",
-            "substates",
-            "dirty_vars",
-            "dirty_substates",
-            "router_data",
-        ):
+        if field in test_state.get_skip_vars():
             continue
         prop = getattr(cls, field)
         assert isinstance(prop, BaseVar)
@@ -819,3 +813,19 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
     assert interdependent_state.get_delta() == {
         interdependent_state.get_full_name(): {"v2x2": 4},
     }
+
+
+def test_child_state():
+    class MainState(State):
+        v: int = 2
+
+    class ChildState(MainState):
+        @ComputedVar
+        def rendered_var(self):
+            return self.v
+
+    ms = MainState()
+    cs = ms.substates[ChildState.get_name()]
+    assert ms.v == 2
+    assert cs.v == 2
+    assert cs.rendered_var == 2