Browse Source

Cache ComputedVar (#917)

Masen Furer 2 years ago
parent
commit
c344a5c0d7
3 changed files with 152 additions and 66 deletions
  1. 36 58
      pynecone/state.py
  2. 83 0
      pynecone/var.py
  3. 33 8
      tests/test_state.py

+ 36 - 58
pynecone/state.py

@@ -74,12 +74,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # Mapping of var name to set of computed variables that depend on it
     computed_var_dependencies: Dict[str, Set[str]] = {}
 
-    # Whether to track accessed vars.
-    track_vars: bool = False
-
-    # The current set of accessed vars during tracking.
-    tracked_vars: Set[str] = set()
-
     def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
         """Initialize the state.
 
@@ -102,22 +96,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             fn.__qualname__ = event_handler.fn.__qualname__  # type: ignore
             setattr(self, name, fn)
 
-        # Initialize the mutable fields.
-        self._init_mutable_fields()
-
         # Initialize computed vars dependencies.
         self.computed_var_dependencies = defaultdict(set)
-        for cvar in self.computed_vars:
-            self.tracked_vars = set()
-
-            # Enable tracking and get the computed var.
-            self.track_vars = True
-            self.__getattribute__(cvar)
-            self.track_vars = False
-
+        for cvar_name, cvar in self.computed_vars.items():
             # Add the dependencies.
-            for var in self.tracked_vars:
-                self.computed_var_dependencies[var].add(cvar)
+            for var in cvar.deps():
+                self.computed_var_dependencies[var].add(cvar_name)
+
+        # Initialize the mutable fields.
+        self._init_mutable_fields()
 
     def _init_mutable_fields(self):
         """Initialize mutable fields.
@@ -199,7 +186,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             **cls.base_vars,
             **cls.computed_vars,
         }
-        cls.computed_var_dependencies = {}
         cls.event_handlers = {}
 
         # Setup the base vars at the class level.
@@ -233,8 +219,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             "dirty_substates",
             "router_data",
             "computed_var_dependencies",
-            "track_vars",
-            "tracked_vars",
         }
 
     @classmethod
@@ -508,8 +492,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
         If the var is inherited, get the var from the parent state.
 
-        If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
-
         Args:
             name: The name of the var.
 
@@ -520,17 +502,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         if not super().__getattribute__("__dict__"):
             return super().__getattribute__(name)
 
-        # Check if tracking is enabled.
-        if super().__getattribute__("track_vars"):
-            # Get the non-computed vars.
-            all_vars = {
-                **super().__getattribute__("vars"),
-                **super().__getattribute__("backend_vars"),
-            }
-            # Add the var to the tracked vars.
-            if name in all_vars:
-                super().__getattribute__("tracked_vars").add(name)
-
         inherited_vars = {
             **super().__getattribute__("inherited_vars"),
             **super().__getattribute__("inherited_backend_vars"),
@@ -676,55 +647,58 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         # Return the state update.
         return StateUpdate(delta=delta, events=events)
 
-    def _dirty_computed_vars(
-        self, from_vars: Optional[Set[str]] = None, check: bool = False
-    ) -> Set[str]:
-        """Get ComputedVars that need to be recomputed based on dirty_vars.
+    def _mark_dirty_computed_vars(self) -> None:
+        """Mark ComputedVars that need to be recalculated based on dirty_vars."""
+        dirty_vars = self.dirty_vars
+        while dirty_vars:
+            calc_vars, dirty_vars = dirty_vars, set()
+            for cvar in self._dirty_computed_vars(from_vars=calc_vars):
+                self.dirty_vars.add(cvar)
+                dirty_vars.add(cvar)
+                actual_var = self.computed_vars.get(cvar)
+                if actual_var:
+                    actual_var.mark_dirty(instance=self)
+
+    def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
+        """Determine ComputedVars that need to be recalculated based on the given vars.
 
         Args:
             from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
-            check: Whether to perform the check.
 
         Returns:
             Set of computed vars to include in the delta.
         """
-        # If checking is disabled, return all computed vars.
-        if not check:
-            return set(self.computed_vars)
-
-        # Return only the computed vars that depend on the dirty vars.
         return set(
             cvar
             for dirty_var in from_vars or self.dirty_vars
-            for cvar in self.computed_vars
-            if cvar in self.computed_var_dependencies.get(dirty_var, set())
+            for cvar in self.computed_var_dependencies[dirty_var]
         )
 
-    def get_delta(self, check: bool = False) -> Delta:
+    def get_delta(self) -> Delta:
         """Get the delta for the state.
 
-        Args:
-            check: Whether to check for dirty computed vars.
-
         Returns:
             The delta for the state.
         """
         delta = {}
 
-        # Return the dirty vars, as well as computed vars depending on dirty vars.
+        # Recursively find the substate deltas.
+        substates = self.substates
+        for substate in self.dirty_substates:
+            delta.update(substates[substate].get_delta())
+
+        # Return the dirty vars and dependent computed vars
+        delta_vars = self.dirty_vars.intersection(self.base_vars).union(
+            self._dirty_computed_vars()
+        )
         subdelta = {
             prop: getattr(self, prop)
-            for prop in self.dirty_vars | self._dirty_computed_vars(check=check)
+            for prop in delta_vars
             if not types.is_backend_variable(prop)
         }
         if len(subdelta) > 0:
             delta[self.get_full_name()] = subdelta
 
-        # Recursively find the substate deltas.
-        substates = self.substates
-        for substate in self.dirty_substates:
-            delta.update(substates[substate].get_delta())
-
         # Format the delta.
         delta = format.format_state(delta)
 
@@ -737,6 +711,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             self.parent_state.dirty_substates.add(self.get_name())
             self.parent_state.mark_dirty()
 
+        # have to mark computed vars dirty to allow access to newly computed
+        # values within the same ComputedVar function
+        self._mark_dirty_computed_vars()
+
     def clean(self):
         """Reset the dirty vars."""
         # Recursively clean the substates.

+ 83 - 0
pynecone/var.py

@@ -1,10 +1,13 @@
 """Define a state var."""
 from __future__ import annotations
 
+import contextlib
+import dis
 import json
 import random
 import string
 from abc import ABC
+from types import FunctionType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -12,9 +15,11 @@ from typing import (
     Dict,
     List,
     Optional,
+    Set,
     Type,
     Union,
     _GenericAlias,  # type: ignore
+    cast,
     get_type_hints,
 )
 
@@ -801,6 +806,84 @@ class ComputedVar(property, Var):
         assert self.fget is not None, "Var must have a getter."
         return self.fget.__name__
 
+    @property
+    def cache_attr(self) -> str:
+        """Get the attribute used to cache the value on the instance.
+
+        Returns:
+            An attribute name.
+        """
+        return f"__cached_{self.name}"
+
+    def __get__(self, instance, owner):
+        """Get the ComputedVar value.
+
+        If the value is already cached on the instance, return the cached value.
+
+        If this ComputedVar doesn't know what type of object it is attached to, then save
+        a reference as self.__objclass__.
+
+        Args:
+            instance: the instance of the class accessing this computed var.
+            owner: the class that this descriptor is attached to.
+
+        Returns:
+            The value of the var for the given instance.
+        """
+        if not hasattr(self, "__objclass__"):
+            self.__objclass__ = owner
+
+        if instance is None:
+            return super().__get__(instance, owner)
+
+        # handle caching
+        if not hasattr(instance, self.cache_attr):
+            setattr(instance, self.cache_attr, super().__get__(instance, owner))
+        return getattr(instance, self.cache_attr)
+
+    def deps(self, obj: Optional[FunctionType] = None) -> Set[str]:
+        """Determine var dependencies of this ComputedVar.
+
+        Save references to attributes accessed on "self".  Recursively called
+        when the function makes a method call on "self".
+
+        Args:
+            obj: the object to disassemble (defaults to the fget function).
+
+        Returns:
+            A set of variable names accessed by the given obj.
+        """
+        d = set()
+        if obj is None:
+            if self.fget is not None:
+                obj = cast(FunctionType, self.fget)
+            else:
+                return set()
+        if not obj.__code__.co_varnames:
+            # cannot reference self if method takes no args
+            return set()
+        self_name = obj.__code__.co_varnames[0]
+        self_is_top_of_stack = False
+        for instruction in dis.get_instructions(obj):
+            if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:
+                self_is_top_of_stack = True
+                continue
+            if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
+                d.add(instruction.argval)
+            elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
+                d.update(self.deps(obj=getattr(self.__objclass__, instruction.argval)))
+            self_is_top_of_stack = False
+        return d
+
+    def mark_dirty(self, instance) -> None:
+        """Mark this ComputedVar as dirty.
+
+        Args:
+            instance: the state instance that needs to recompute the value.
+        """
+        with contextlib.suppress(AttributeError):
+            delattr(instance, self.cache_attr)
+
     @property
     def type_(self):
         """Get the type of the var.

+ 33 - 8
tests/test_state.py

@@ -485,11 +485,11 @@ def test_set_dirty_var(test_state):
 
     # Setting a var should mark it as dirty.
     test_state.num1 = 1
-    assert test_state.dirty_vars == {"num1"}
+    assert test_state.dirty_vars == {"num1", "sum"}
 
     # Setting another var should mark it as dirty.
     test_state.num2 = 2
-    assert test_state.dirty_vars == {"num1", "num2"}
+    assert test_state.dirty_vars == {"num1", "num2", "sum"}
 
     # Cleaning the state should remove all dirty vars.
     test_state.clean()
@@ -578,7 +578,7 @@ async def test_process_event_simple(test_state):
     assert test_state.num1 == 69
 
     # The delta should contain the changes, including computed vars.
-    assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
+    assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
     assert update.events == []
 
 
@@ -601,7 +601,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     assert child_state.value == "HI"
     assert child_state.count == 24
     assert update.delta == {
-        "test_state": {"sum": 3.14, "upper": ""},
         "test_state.child_state": {"value": "HI", "count": 24},
     }
     test_state.clean()
@@ -616,7 +615,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
     update = await test_state._process(event)
     assert grandchild_state.value2 == "new"
     assert update.delta == {
-        "test_state": {"sum": 3.14, "upper": ""},
         "test_state.child_state.grandchild_state": {"value2": "new"},
     }
 
@@ -791,7 +789,7 @@ def test_not_dirty_computed_var_from_var(interdependent_state):
         interdependent_state: A state with varying Var dependencies.
     """
     interdependent_state.x = 5
-    assert interdependent_state.get_delta(check=True) == {
+    assert interdependent_state.get_delta() == {
         interdependent_state.get_full_name(): {"x": 5},
     }
 
@@ -806,7 +804,7 @@ def test_dirty_computed_var_from_var(interdependent_state):
         interdependent_state: A state with varying Var dependencies.
     """
     interdependent_state.v1 = 1
-    assert interdependent_state.get_delta(check=True) == {
+    assert interdependent_state.get_delta() == {
         interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
     }
 
@@ -818,7 +816,7 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
         interdependent_state: A state with varying Var dependencies.
     """
     interdependent_state._v2 = 2
-    assert interdependent_state.get_delta(check=True) == {
+    assert interdependent_state.get_delta() == {
         interdependent_state.get_full_name(): {"v2x2": 4},
     }
 
@@ -860,6 +858,7 @@ def test_conditional_computed_vars():
     assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
     assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
     assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
+    assert ms.computed_vars["rendered_var"].deps() == {"flag", "t1", "t2"}
 
 
 def test_event_handlers_convert_to_fns(test_state, child_state):
@@ -896,3 +895,29 @@ def test_event_handlers_call_other_handlers():
     ms = MainState()
     ms.set_v2(1)
     assert ms.v == 1
+
+
+def test_computed_var_cached():
+    """Test that a ComputedVar doesn't recalculate when accessed."""
+    comp_v_calls = 0
+
+    class ComputedState(State):
+        v: int = 0
+
+        @ComputedVar
+        def comp_v(self) -> int:
+            nonlocal comp_v_calls
+            comp_v_calls += 1
+            return self.v
+
+    cs = ComputedState()
+    assert cs.dict()["v"] == 0
+    assert comp_v_calls == 1
+    assert cs.dict()["comp_v"] == 0
+    assert comp_v_calls == 1
+    assert cs.comp_v == 0
+    assert comp_v_calls == 1
+    cs.v = 1
+    assert comp_v_calls == 1
+    assert cs.comp_v == 1
+    assert comp_v_calls == 2