Browse Source

@pc.cached_var: explicit opt-in for ComputedVar tracking (#1000)

Masen Furer 2 years ago
parent
commit
0491852a45
4 changed files with 182 additions and 63 deletions
  1. 1 0
      pynecone/__init__.py
  2. 43 25
      pynecone/state.py
  3. 21 1
      pynecone/vars.py
  4. 117 37
      tests/test_state.py

+ 1 - 0
pynecone/__init__.py

@@ -31,3 +31,4 @@ from .state import ComputedVar as var
 from .state import State as State
 from .style import toggle_color_mode as toggle_color_mode
 from .vars import Var as Var
+from .vars import cached_var as cached_var

+ 43 - 25
pynecone/state.py

@@ -688,22 +688,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         # Return the state update.
         return StateUpdate(delta=delta, events=events)
 
+    def _always_dirty_computed_vars(self) -> Set[str]:
+        """The set of ComputedVars that always need to be recalculated.
+
+        Returns:
+            Set of all ComputedVar in this state where cache=False
+        """
+        return set(
+            cvar_name
+            for cvar_name, cvar in self.computed_vars.items()
+            if not cvar.cache
+        )
+
     def _mark_dirty_computed_vars(self) -> None:
         """Mark ComputedVars that need to be recalculated based on dirty_vars."""
-        # Mark all ComputedVars as dirty.
-        for cvar in self.computed_vars.values():
-            cvar.mark_dirty(instance=self)
-
-        # TODO: Uncomment the actual implementation below.
-        # dirty_vars = self.dirty_vars
-        # while dirty_vars:
-        #     calc_vars, dirty_vars = dirty_vars, set()
-        #     for cvar in self._dirty_computed_vars(from_vars=calc_vars):
-        #         self.dirty_vars.add(cvar)
-        #         dirty_vars.add(cvar)
-        #         actual_var = self.computed_vars.get(cvar)
-        #         if actual_var:
-        #             actual_var.mark_dirty(instance=self)
+        dirty_vars = self.dirty_vars
+        while dirty_vars:
+            calc_vars, dirty_vars = dirty_vars, set()
+            for cvar in self._dirty_computed_vars(from_vars=calc_vars):
+                self.dirty_vars.add(cvar)
+                dirty_vars.add(cvar)
+                actual_var = self.computed_vars.get(cvar)
+                if actual_var:
+                    actual_var.mark_dirty(instance=self)
 
     def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
         """Determine ComputedVars that need to be recalculated based on the given vars.
@@ -714,13 +721,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             Set of computed vars to include in the delta.
         """
-        return set(self.computed_vars)
-        # TODO: Uncomment the actual implementation below.
-        # return set(
-        #     cvar
-        #     for dirty_var in from_vars or self.dirty_vars
-        #     for cvar in self.computed_var_dependencies[dirty_var]
-        # )
+        return set(
+            cvar
+            for dirty_var in from_vars or self.dirty_vars
+            for cvar in self.computed_var_dependencies[dirty_var]
+        )
 
     def get_delta(self) -> Delta:
         """Get the delta for the state.
@@ -730,11 +735,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         """
         delta = {}
 
-        # Return the dirty vars and dependent computed vars
-        self._mark_dirty_computed_vars()
-        delta_vars = self.dirty_vars.intersection(self.base_vars).union(
-            self._dirty_computed_vars()
+        # Apply dirty variables down into substates
+        self.dirty_vars.update(self._always_dirty_computed_vars())
+        self.mark_dirty()
+
+        # Return the dirty vars for this instance, any cached/dependent computed vars,
+        # and always dirty computed vars (cache=False)
+        delta_vars = (
+            self.dirty_vars.intersection(self.base_vars)
+            .union(self._dirty_computed_vars())
+            .union(self._always_dirty_computed_vars())
         )
+
         subdelta = {
             prop: getattr(self, prop)
             for prop in delta_vars
@@ -797,6 +809,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The object as a dictionary.
         """
+        if include_computed:
+            # Apply dirty variables down into substates to allow never-cached ComputedVar to
+            # trigger recalculation of dependent vars
+            self.dirty_vars.update(self._always_dirty_computed_vars())
+            self.mark_dirty()
+
         base_vars = {
             prop_name: self.get_value(getattr(self, prop_name))
             for prop_name in self.base_vars

+ 21 - 1
pynecone/vars.py

@@ -801,6 +801,9 @@ class BaseVar(Var, Base):
 class ComputedVar(Var, property):
     """A field with computed getters."""
 
+    # Whether to track dependencies and cache computed values
+    cache: bool = False
+
     @property
     def name(self) -> str:
         """Get the name of the var.
@@ -832,7 +835,7 @@ class ComputedVar(Var, property):
         Returns:
             The value of the var for the given instance.
         """
-        if instance is None:
+        if instance is None or not self.cache:
             return super().__get__(instance, owner)
 
         # handle caching
@@ -906,6 +909,23 @@ class ComputedVar(Var, property):
         return Any
 
 
+def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
+    """A field with computed getter that tracks other state dependencies.
+
+    The cached_var will only be recalculated when other state vars that it
+    depends on are modified.
+
+    Args:
+        fget: the function that calculates the variable value.
+
+    Returns:
+        ComputedVar that is recomputed when dependencies change.
+    """
+    cvar = ComputedVar(fget=fget)
+    cvar.cache = True
+    return cvar
+
+
 class PCList(list):
     """A custom list that pynecone can detect its mutation."""
 

+ 117 - 37
tests/test_state.py

@@ -3,6 +3,7 @@ from typing import Dict, List
 import pytest
 from plotly.graph_objects import Figure
 
+import pynecone as pc
 from pynecone.base import Base
 from pynecone.constants import IS_HYDRATED, RouteVar
 from pynecone.event import Event, EventHandler
@@ -484,13 +485,11 @@ def test_set_dirty_var(test_state):
 
     # Setting a var should mark it as dirty.
     test_state.num1 = 1
-    # assert test_state.dirty_vars == {"num1", "sum"}
-    assert test_state.dirty_vars == {"num1"}
+    assert test_state.dirty_vars == {"num1", "sum"}
 
     # Setting another var should mark it as dirty.
     test_state.num2 = 2
-    # assert test_state.dirty_vars == {"num1", "num2", "sum"}
-    assert test_state.dirty_vars == {"num1", "num2"}
+    assert test_state.dirty_vars == {"num1", "num2", "sum"}
 
     # Cleaning the state should remove all dirty vars.
     test_state.clean()
@@ -746,7 +745,7 @@ class InterdependentState(State):
     v1: int = 0
     _v2: int = 1
 
-    @ComputedVar
+    @pc.cached_var
     def v1x2(self) -> int:
         """Depends on var v1.
 
@@ -755,7 +754,7 @@ class InterdependentState(State):
         """
         return self.v1 * 2
 
-    @ComputedVar
+    @pc.cached_var
     def v2x2(self) -> int:
         """Depends on backend var _v2.
 
@@ -764,7 +763,7 @@ class InterdependentState(State):
         """
         return self._v2 * 2
 
-    @ComputedVar
+    @pc.cached_var
     def v1x2x2(self) -> int:
         """Depends on ComputedVar v1x2.
 
@@ -786,43 +785,43 @@ def interdependent_state() -> State:
     return s
 
 
-# def test_not_dirty_computed_var_from_var(interdependent_state):
-#     """Set Var that no ComputedVar depends on, expect no recalculation.
+def test_not_dirty_computed_var_from_var(interdependent_state):
+    """Set Var that no ComputedVar depends on, expect no recalculation.
 
-#     Args:
-#         interdependent_state: A state with varying Var dependencies.
-#     """
-#     interdependent_state.x = 5
-#     assert interdependent_state.get_delta() == {
-#         interdependent_state.get_full_name(): {"x": 5},
-#     }
+    Args:
+        interdependent_state: A state with varying Var dependencies.
+    """
+    interdependent_state.x = 5
+    assert interdependent_state.get_delta() == {
+        interdependent_state.get_full_name(): {"x": 5},
+    }
 
 
-# def test_dirty_computed_var_from_var(interdependent_state):
-#     """Set Var that ComputedVar depends on, expect recalculation.
+def test_dirty_computed_var_from_var(interdependent_state):
+    """Set Var that ComputedVar depends on, expect recalculation.
 
-#     The other ComputedVar depends on the changed ComputedVar and should also be
-#     recalculated. No other ComputedVars should be recalculated.
+    The other ComputedVar depends on the changed ComputedVar and should also be
+    recalculated. No other ComputedVars should be recalculated.
 
-#     Args:
-#         interdependent_state: A state with varying Var dependencies.
-#     """
-#     interdependent_state.v1 = 1
-#     assert interdependent_state.get_delta() == {
-#         interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
-#     }
+    Args:
+        interdependent_state: A state with varying Var dependencies.
+    """
+    interdependent_state.v1 = 1
+    assert interdependent_state.get_delta() == {
+        interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
+    }
 
 
-# def test_dirty_computed_var_from_backend_var(interdependent_state):
-#     """Set backend var that ComputedVar depends on, expect recalculation.
+def test_dirty_computed_var_from_backend_var(interdependent_state):
+    """Set backend var that ComputedVar depends on, expect recalculation.
 
-#     Args:
-#         interdependent_state: A state with varying Var dependencies.
-#     """
-#     interdependent_state._v2 = 2
-#     assert interdependent_state.get_delta() == {
-#         interdependent_state.get_full_name(): {"v2x2": 4},
-#     }
+    Args:
+        interdependent_state: A state with varying Var dependencies.
+    """
+    interdependent_state._v2 = 2
+    assert interdependent_state.get_delta() == {
+        interdependent_state.get_full_name(): {"v2x2": 4},
+    }
 
 
 def test_per_state_backend_var(interdependent_state):
@@ -932,7 +931,7 @@ def test_computed_var_cached():
     class ComputedState(State):
         v: int = 0
 
-        @ComputedVar
+        @pc.cached_var
         def comp_v(self) -> int:
             nonlocal comp_v_calls
             comp_v_calls += 1
@@ -949,3 +948,84 @@ def test_computed_var_cached():
     assert comp_v_calls == 1
     assert cs.comp_v == 1
     assert comp_v_calls == 2
+
+
+def test_computed_var_cached_depends_on_non_cached():
+    """Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
+
+    class ComputedState(State):
+        v: int = 0
+
+        @pc.var
+        def no_cache_v(self) -> int:
+            return self.v
+
+        @pc.cached_var
+        def dep_v(self) -> int:
+            return self.no_cache_v
+
+        @pc.cached_var
+        def comp_v(self) -> int:
+            return self.v
+
+    cs = ComputedState()
+    assert cs.dirty_vars == set()
+    assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
+    cs.clean()
+    assert cs.dirty_vars == set()
+    assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
+    cs.clean()
+    assert cs.dirty_vars == set()
+    cs.v = 1
+    assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
+    assert cs.get_delta() == {
+        cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
+    }
+    cs.clean()
+    assert cs.dirty_vars == set()
+    assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
+    cs.clean()
+    assert cs.dirty_vars == set()
+    assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
+    cs.clean()
+    assert cs.dirty_vars == set()
+
+
+def test_computed_var_depends_on_parent_non_cached():
+    """Child state cached_var that depends on parent state un cached var is always recalculated."""
+    counter = 0
+
+    class ParentState(State):
+        @pc.var
+        def no_cache_v(self) -> int:
+            nonlocal counter
+            counter += 1
+            return counter
+
+    class ChildState(ParentState):
+        @pc.cached_var
+        def dep_v(self) -> int:
+            return self.no_cache_v
+
+    ps = ParentState()
+    cs = ps.substates[ChildState.get_name()]
+
+    assert ps.dirty_vars == set()
+    assert cs.dirty_vars == set()
+
+    assert ps.dict() == {
+        cs.get_name(): {"dep_v": 2},
+        "no_cache_v": 1,
+        IS_HYDRATED: False,
+    }
+    assert ps.dict() == {
+        cs.get_name(): {"dep_v": 4},
+        "no_cache_v": 3,
+        IS_HYDRATED: False,
+    }
+    assert ps.dict() == {
+        cs.get_name(): {"dep_v": 6},
+        "no_cache_v": 5,
+        IS_HYDRATED: False,
+    }
+    assert counter == 6