Jelajahi Sumber

[REF-1035] Track ComputedVar dependency per class (#2067)

Masen Furer 1 tahun lalu
induk
melakukan
ee87e62efa
3 mengubah file dengan 108 tambahan dan 67 penghapusan
  1. 100 59
      reflex/state.py
  2. 2 2
      tests/test_app.py
  3. 6 6
      tests/test_state.py

+ 100 - 59
reflex/state.py

@@ -143,6 +143,15 @@ class RouterData(Base):
         self.page = PageData(router_data)
 
 
+RESERVED_BACKEND_VAR_NAMES = {
+    "_backend_vars",
+    "_computed_var_dependencies",
+    "_substate_var_dependencies",
+    "_always_dirty_computed_vars",
+    "_always_dirty_substates",
+}
+
+
 class State(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
 
@@ -167,6 +176,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # The event handlers.
     event_handlers: ClassVar[Dict[str, EventHandler]] = {}
 
+    # Mapping of var name to set of computed variables that depend on it
+    _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
+
+    # Mapping of var name to set of substates that depend on it
+    _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
+
+    # Set of vars which always need to be recomputed
+    _always_dirty_computed_vars: ClassVar[Set[str]] = set()
+
+    # Set of substates which always need to be recomputed
+    _always_dirty_substates: ClassVar[Set[str]] = set()
+
     # The parent state.
     parent_state: Optional[State] = None
 
@@ -182,12 +203,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # The routing path that triggered the state
     router_data: Dict[str, Any] = {}
 
-    # Mapping of var name to set of computed variables that depend on it
-    computed_var_dependencies: Dict[str, Set[str]] = {}
-
-    # Mapping of var name to set of substates that depend on it
-    substate_var_dependencies: Dict[str, Set[str]] = {}
-
     # Per-instance copy of backend variable values
     _backend_vars: Dict[str, Any] = {}
 
@@ -211,10 +226,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         kwargs["parent_state"] = parent_state
         super().__init__(*args, **kwargs)
 
-        # initialize per-instance var dependency tracking
-        self.computed_var_dependencies = defaultdict(set)
-        self.substate_var_dependencies = defaultdict(set)
-
         # Setup the substates.
         for substate in self.get_substates():
             substate_name = substate.get_name()
@@ -227,25 +238,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         # Convert the event handlers to functions.
         self._init_event_handlers()
 
-        # Initialize computed vars dependencies.
-        inherited_vars = set(self.inherited_vars).union(
-            set(self.inherited_backend_vars),
-        )
-        for cvar_name, cvar in self.computed_vars.items():
-            # Add the dependencies.
-            for var in cvar._deps(objclass=type(self)):
-                self.computed_var_dependencies[var].add(cvar_name)
-                if var in inherited_vars:
-                    # track that this substate depends on its parent for this var
-                    state_name = self.get_name()
-                    parent_state = self.parent_state
-                    while parent_state is not None and var in parent_state.vars:
-                        parent_state.substate_var_dependencies[var].add(state_name)
-                        state_name, parent_state = (
-                            parent_state.get_name(),
-                            parent_state.parent_state,
-                        )
-
         # Create a fresh copy of the backend variables for this instance
         self._backend_vars = copy.deepcopy(self.backend_vars)
 
@@ -347,6 +339,60 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             cls.event_handlers[name] = handler
             setattr(cls, name, handler)
 
+        cls._init_var_dependency_dicts()
+
+    @classmethod
+    def _init_var_dependency_dicts(cls):
+        """Initialize the var dependency tracking dicts.
+
+        Allows the state to know which vars each ComputedVar depends on and
+        whether a ComputedVar depends on a var in its parent state.
+
+        Additional updates tracking dicts for vars and substates that always
+        need to be recomputed.
+        """
+        # Initialize per-class var dependency tracking.
+        cls._computed_var_dependencies = defaultdict(set)
+        cls._substate_var_dependencies = defaultdict(set)
+
+        inherited_vars = set(cls.inherited_vars).union(
+            set(cls.inherited_backend_vars),
+        )
+        for cvar_name, cvar in cls.computed_vars.items():
+            # Add the dependencies.
+            for var in cvar._deps(objclass=cls):
+                cls._computed_var_dependencies[var].add(cvar_name)
+                if var in inherited_vars:
+                    # track that this substate depends on its parent for this var
+                    state_name = cls.get_name()
+                    parent_state = cls.get_parent_state()
+                    while parent_state is not None and var in parent_state.vars:
+                        parent_state._substate_var_dependencies[var].add(state_name)
+                        state_name, parent_state = (
+                            parent_state.get_name(),
+                            parent_state.get_parent_state(),
+                        )
+
+        # ComputedVar with cache=False always need to be recomputed
+        cls._always_dirty_computed_vars = set(
+            cvar_name
+            for cvar_name, cvar in cls.computed_vars.items()
+            if not cvar._cache
+        )
+
+        # Any substate containing a ComputedVar with cache=False always needs to be recomputed
+        cls._always_dirty_substates = set()
+        if cls._always_dirty_computed_vars:
+            # Tell parent classes that this substate has always dirty computed vars
+            state_name = cls.get_name()
+            parent_state = cls.get_parent_state()
+            while parent_state is not None:
+                parent_state._always_dirty_substates.add(state_name)
+                state_name, parent_state = (
+                    parent_state.get_name(),
+                    parent_state.get_parent_state(),
+                )
+
     @classmethod
     def _check_overridden_methods(cls):
         """Check for shadow methods and raise error if any.
@@ -377,16 +423,17 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The vars to skip when serializing.
         """
-        return set(cls.inherited_vars) | {
-            "parent_state",
-            "substates",
-            "dirty_vars",
-            "dirty_substates",
-            "router_data",
-            "computed_var_dependencies",
-            "substate_var_dependencies",
-            "_backend_vars",
-        }
+        return (
+            set(cls.inherited_vars)
+            | {
+                "parent_state",
+                "substates",
+                "dirty_vars",
+                "dirty_substates",
+                "router_data",
+            }
+            | RESERVED_BACKEND_VAR_NAMES
+        )
 
     @classmethod
     @functools.lru_cache()
@@ -540,6 +587,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         for substate_class in cls.__subclasses__():
             substate_class.vars.setdefault(name, var)
 
+        # Reinitialize dependency tracking dicts.
+        cls._init_var_dependency_dicts()
+
     @classmethod
     def _set_var(cls, prop: BaseVar):
         """Set the var as a class member.
@@ -749,6 +799,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls)  # type: ignore
             setattr(cls, param, func)
 
+            # Reinitialize dependency tracking dicts.
+            cls._init_var_dependency_dicts()
+
     def __getattribute__(self, name: str) -> Any:
         """Get the state var.
 
@@ -804,7 +857,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             setattr(self.parent_state, name, value)
             return
 
-        if types.is_backend_variable(name) and name != "_backend_vars":
+        if types.is_backend_variable(name) and name not in RESERVED_BACKEND_VAR_NAMES:
             self._backend_vars.__setitem__(name, value)
             self.dirty_vars.add(name)
             self._mark_dirty()
@@ -814,7 +867,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         super().__setattr__(name, value)
 
         # Add the var to the dirty list.
-        if name in self.vars or name in self.computed_var_dependencies:
+        if name in self.vars or name in self._computed_var_dependencies:
             self.dirty_vars.add(name)
             self._mark_dirty()
 
@@ -1056,18 +1109,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
                 final=True,
             )
 
-    def _always_dirty_computed_vars(self) -> set[str]:
-        """The set of ComputedVars that always need to be recalculated.
-
-        Returns:
-            Set of all ComputedVar in this state where cache=False
-        """
-        return set(
-            cvar_name
-            for cvar_name, cvar in self.computed_vars.items()
-            if not cvar._cache
-        )
-
     def _mark_dirty_computed_vars(self) -> None:
         """Mark ComputedVars that need to be recalculated based on dirty_vars."""
         dirty_vars = self.dirty_vars
@@ -1092,7 +1133,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         return set(
             cvar
             for dirty_var in from_vars or self.dirty_vars
-            for cvar in self.computed_var_dependencies[dirty_var]
+            for cvar in self._computed_var_dependencies[dirty_var]
         )
 
     def get_delta(self) -> Delta:
@@ -1104,7 +1145,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         delta = {}
 
         # Apply dirty variables down into substates
-        self.dirty_vars.update(self._always_dirty_computed_vars())
+        self.dirty_vars.update(self._always_dirty_computed_vars)
         self._mark_dirty()
 
         # Return the dirty vars for this instance, any cached/dependent computed vars,
@@ -1112,7 +1153,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         delta_vars = (
             self.dirty_vars.intersection(self.base_vars)
             .union(self._dirty_computed_vars())
-            .union(self._always_dirty_computed_vars())
+            .union(self._always_dirty_computed_vars)
         )
 
         subdelta = {
@@ -1125,7 +1166,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
         # Recursively find the substate deltas.
         substates = self.substates
-        for substate in self.dirty_substates:
+        for substate in self.dirty_substates.union(self._always_dirty_substates):
             delta.update(substates[substate].get_delta())
 
         # Format the delta.
@@ -1151,7 +1192,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         # Propagate dirty var / computed var status into substates
         substates = self.substates
         for var in self.dirty_vars:
-            for substate_name in self.substate_var_dependencies[var]:
+            for substate_name in self._substate_var_dependencies[var]:
                 self.dirty_substates.add(substate_name)
                 substate = substates[substate_name]
                 substate.dirty_vars.add(var)
@@ -1195,7 +1236,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         if include_computed:
             # Apply dirty variables down into substates to allow never-cached ComputedVar to
             # trigger recalculation of dependent vars
-            self.dirty_vars.update(self._always_dirty_computed_vars())
+            self.dirty_vars.update(self._always_dirty_computed_vars)
             self._mark_dirty()
 
         base_vars = {

+ 2 - 2
tests/test_app.py

@@ -257,7 +257,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
     assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
         constants.ROUTER
     }
-    assert constants.ROUTER in app.state().computed_var_dependencies
+    assert constants.ROUTER in app.state()._computed_var_dependencies
 
 
 def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@@ -917,7 +917,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
         constants.ROUTER
     }
-    assert constants.ROUTER in app.state().computed_var_dependencies
+    assert constants.ROUTER in app.state()._computed_var_dependencies
 
     sid = "mock_sid"
     client_ip = "127.0.0.1"

+ 6 - 6
tests/test_state.py

@@ -1215,7 +1215,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
         assert isinstance(HandlerState.handler, EventHandler)
 
     s = HandlerState()
-    assert "cached_x_side_effect" in s.computed_var_dependencies["x"]
+    assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
     assert s.cached_x_side_effect == 1
     assert s.x == 43
     s.handler()
@@ -1283,11 +1283,11 @@ def test_computed_var_dependencies():
             return [z in self._z for z in range(5)]
 
     cs = ComputedState()
-    assert cs.computed_var_dependencies["v"] == {"comp_v"}
-    assert cs.computed_var_dependencies["w"] == {"comp_w"}
-    assert cs.computed_var_dependencies["x"] == {"comp_x"}
-    assert cs.computed_var_dependencies["y"] == {"comp_y"}
-    assert cs.computed_var_dependencies["_z"] == {"comp_z"}
+    assert cs._computed_var_dependencies["v"] == {"comp_v"}
+    assert cs._computed_var_dependencies["w"] == {"comp_w"}
+    assert cs._computed_var_dependencies["x"] == {"comp_x"}
+    assert cs._computed_var_dependencies["y"] == {"comp_y"}
+    assert cs._computed_var_dependencies["_z"] == {"comp_z"}
 
 
 def test_backend_method():