Browse Source

Only update ComputedVar when dependent vars change (#840)

Masen Furer 2 years ago
parent
commit
b4755b8123
2 changed files with 165 additions and 5 deletions
  1. 69 2
      pynecone/state.py
  2. 96 3
      tests/test_state.py

+ 69 - 2
pynecone/state.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import asyncio
 import functools
+import inspect
 import traceback
 from abc import ABC
 from typing import (
@@ -14,6 +15,7 @@ from typing import (
     Optional,
     Sequence,
     Set,
+    Tuple,
     Type,
     Union,
 )
@@ -51,6 +53,9 @@ class State(Base, ABC):
     # Backend vars inherited
     inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
 
+    # Mapping of var name to set of computed variables that depend on it
+    computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
+
     # The event handlers.
     event_handlers: ClassVar[Dict[str, EventHandler]] = {}
 
@@ -171,6 +176,7 @@ class State(Base, ABC):
             **cls.base_vars,
             **cls.computed_vars,
         }
+        cls.computed_var_dependencies = {}
 
         # Setup the base vars at the class level.
         for prop in cls.base_vars.values():
@@ -472,12 +478,28 @@ class State(Base, ABC):
 
         If the var is inherited, get the var from the parent state.
 
+        If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
+
         Args:
             name: The name of the var.
 
         Returns:
             The value of the var.
         """
+        vars = {
+            **super().__getattribute__("vars"),
+            **super().__getattribute__("backend_vars"),
+        }
+        if name in vars:
+            parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
+            if parent_frame is not None:
+                computed_vars = super().__getattribute__("computed_vars")
+                requesting_attribute_name = parent_frame_locals.get("name")
+                if requesting_attribute_name in computed_vars:
+                    # Keep track of any ComputedVar that depends on this Var
+                    super().__getattribute__("computed_var_dependencies").setdefault(
+                        name, set()
+                    ).add(requesting_attribute_name)
         inherited_vars = {
             **super().__getattribute__("inherited_vars"),
             **super().__getattribute__("inherited_backend_vars"),
@@ -505,6 +527,7 @@ class State(Base, ABC):
 
         if types.is_backend_variable(name):
             self.backend_vars.__setitem__(name, value)
+            self.dirty_vars.add(name)
             self.mark_dirty()
             return
 
@@ -622,6 +645,28 @@ class State(Base, ABC):
         # Return the state update.
         return StateUpdate(delta=delta, events=events)
 
+    def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
+        """Get ComputedVars that need to be recomputed based on dirty_vars.
+
+        Args:
+            from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
+
+        Returns:
+            Set of computed vars to include in the delta.
+        """
+        dirty_computed_vars = set(
+            cvar
+            for dirty_var in from_vars or self.dirty_vars
+            for cvar in self.computed_vars
+            if cvar in self.computed_var_dependencies.get(dirty_var, set())
+        )
+        if dirty_computed_vars:
+            # recursive call to catch computed vars that depend on computed vars
+            return dirty_computed_vars | self._dirty_computed_vars(
+                from_vars=dirty_computed_vars
+            )
+        return dirty_computed_vars
+
     def get_delta(self) -> Delta:
         """Get the delta for the state.
 
@@ -630,10 +675,11 @@ class State(Base, ABC):
         """
         delta = {}
 
-        # Return the dirty vars, as well as all computed vars.
+        # Return the dirty vars, as well as computed vars depending on dirty vars.
         subdelta = {
             prop: getattr(self, prop)
-            for prop in self.dirty_vars | self.computed_vars.keys()
+            for prop in self.dirty_vars | self._dirty_computed_vars()
+            if not types.is_backend_variable(prop)
         }
         if len(subdelta) > 0:
             delta[self.get_full_name()] = subdelta
@@ -803,3 +849,24 @@ def _convert_mutable_datatypes(
             field_value, reassign_field=reassign_field, field_name=field_name
         )
     return field_value
+
+
+def _get_previous_recursive_frame_info() -> (
+    Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
+):
+    """Find the previous frame of the same function that calls this helper.
+
+    For example, if this function is called from `State.__getattribute__`
+    (parent frame), then the returned frame will be the next earliest call
+    of the same function.
+
+    Returns:
+        Tuple of (frame_info, local_vars)
+
+    If no previous recursive frame is found up the stack, the frame info will be None.
+    """
+    _this_frame, parent_frame, *prev_frames = inspect.stack()
+    for frame in prev_frames:
+        if frame.frame.f_code == parent_frame.frame.f_code:
+            return frame, frame.frame.f_locals
+    return None, {}

+ 96 - 3
tests/test_state.py

@@ -582,7 +582,7 @@ async def test_process_event_simple(test_state):
     assert test_state.num1 == 69
 
     # The delta should contain the changes, including computed vars.
-    assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
+    assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
     assert update.events == []
 
 
@@ -606,7 +606,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     assert child_state.count == 24
     assert update.delta == {
         "test_state.child_state": {"value": "HI", "count": 24},
-        "test_state": {"sum": 3.14, "upper": ""},
     }
     test_state.clean()
 
@@ -621,7 +620,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     assert grandchild_state.value2 == "new"
     assert update.delta == {
         "test_state.child_state.grandchild_state": {"value2": "new"},
-        "test_state": {"sum": 3.14, "upper": ""},
     }
 
 
@@ -724,3 +722,98 @@ def test_add_var(test_state):
     test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
     assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
     assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
+
+
+class InterdependentState(State):
+    """A state with 3 vars and 3 computed vars.
+
+    x: a variable that no computed var depends on
+    v1: a varable that one computed var directly depeneds on
+    _v2: a backend variable that one computed var directly depends on
+
+    v1x2: a computed var that depends on v1
+    v2x2: a computed var that depends on backend var _v2
+    v1x2x2: a computed var that depends on computed var v1x2
+    """
+
+    x: int = 0
+    v1: int = 0
+    _v2: int = 1
+
+    @ComputedVar
+    def v1x2(self) -> int:
+        """depends on var v1.
+
+        Returns:
+            Var v1 multiplied by 2
+        """
+        return self.v1 * 2
+
+    @ComputedVar
+    def v2x2(self) -> int:
+        """depends on backend var _v2.
+
+        Returns:
+            backend var _v2 multiplied by 2
+        """
+        return self._v2 * 2
+
+    @ComputedVar
+    def v1x2x2(self) -> int:
+        """depends on ComputedVar v1x2.
+
+        Returns:
+            ComputedVar v1x2 multiplied by 2
+        """
+        return self.v1x2 * 2
+
+
+@pytest.fixture
+def interdependent_state() -> State:
+    """A state with varying dependency between vars.
+
+    Returns:
+        instance of InterdependentState
+    """
+    s = InterdependentState()
+    s.dict()  # prime initial relationships by accessing all ComputedVars
+    return s
+
+
+def test_not_dirty_computed_var_from_var(interdependent_state):
+    """Set Var that no ComputedVar depends on, expect no recalculation.
+
+    Args:
+        interdependent_state: A state with varying Var dependencies.
+    """
+    interdependent_state.x = 5
+    assert interdependent_state.get_delta() == {
+        interdependent_state.get_full_name(): {"x": 5},
+    }
+
+
+def test_dirty_computed_var_from_var(interdependent_state):
+    """Set Var that ComputedVar depends on, expect recalculation.
+
+    The other ComputedVar depends on the changed ComputedVar and should also be
+    recalculated. No other ComputedVars should be recalculated.
+
+    Args:
+        interdependent_state: A state with varying Var dependencies.
+    """
+    interdependent_state.v1 = 1
+    assert interdependent_state.get_delta() == {
+        interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
+    }
+
+
+def test_dirty_computed_var_from_backend_var(interdependent_state):
+    """Set backend var that ComputedVar depends on, expect recalculation.
+
+    Args:
+        interdependent_state: A state with varying Var dependencies.
+    """
+    interdependent_state._v2 = 2
+    assert interdependent_state.get_delta() == {
+        interdependent_state.get_full_name(): {"v2x2": 4},
+    }