Explorar o código

fix var dependency dicts (#3842)

benedikt-bartscher hai 8 meses
pai
achega
99ffbc8c39
Modificáronse 3 ficheiros con 64 adicións e 22 borrados
  1. 12 11
      reflex/state.py
  2. 1 10
      tests/test_app.py
  3. 51 1
      tests/test_state.py

+ 12 - 11
reflex/state.py

@@ -584,6 +584,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             cls.event_handlers[name] = handler
             setattr(cls, name, handler)
 
+        # Initialize per-class var dependency tracking.
+        cls._computed_var_dependencies = defaultdict(set)
+        cls._substate_var_dependencies = defaultdict(set)
         cls._init_var_dependency_dicts()
 
     @staticmethod
@@ -653,10 +656,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Additional updates tracking dicts for vars and substates that always
         need to be recomputed.
         """
-        # Initialize per-class var dependency tracking.
-        cls._computed_var_dependencies = defaultdict(set)
-        cls._substate_var_dependencies = defaultdict(set)
-
         inherited_vars = set(cls.inherited_vars).union(
             set(cls.inherited_backend_vars),
         )
@@ -1006,20 +1005,20 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Args:
             args: a dict of args
         """
+        if not args:
+            return
 
         def argsingle_factory(param):
-            @ComputedVar
             def inner_func(self) -> str:
                 return self.router.page.params.get(param, "")
 
-            return inner_func
+            return ComputedVar(fget=inner_func, cache=True)
 
         def arglist_factory(param):
-            @ComputedVar
             def inner_func(self) -> List:
                 return self.router.page.params.get(param, [])
 
-            return inner_func
+            return ComputedVar(fget=inner_func, cache=True)
 
         for param, value in args.items():
             if value == constants.RouteArgType.SINGLE:
@@ -1033,8 +1032,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls)  # type: ignore
             setattr(cls, param, func)
 
-            # Reinitialize dependency tracking dicts.
-            cls._init_var_dependency_dicts()
+        # Reinitialize dependency tracking dicts.
+        cls._init_var_dependency_dicts()
 
     def __getattribute__(self, name: str) -> Any:
         """Get the state var.
@@ -3600,5 +3599,7 @@ def reload_state_module(
         if subclass.__module__ == module and module is not None:
             state.class_subclasses.remove(subclass)
             state._always_dirty_substates.discard(subclass.get_name())
-    state._init_var_dependency_dicts()
+            state._computed_var_dependencies = defaultdict(set)
+            state._substate_var_dependencies = defaultdict(set)
+            state._init_var_dependency_dicts()
     state.get_class_substate.cache_clear()

+ 1 - 10
tests/test_app.py

@@ -906,7 +906,7 @@ class DynamicState(BaseState):
         """Increment the counter var."""
         self.counter = self.counter + 1
 
-    @computed_var
+    @computed_var(cache=True)
     def comp_dynamic(self) -> str:
         """A computed var that depends on the dynamic var.
 
@@ -1049,9 +1049,6 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         assert on_load_update == StateUpdate(
             delta={
                 state.get_name(): {
-                    # These computed vars _shouldn't_ be here, because they didn't change
-                    arg_name: exp_val,
-                    f"comp_{arg_name}": exp_val,
                     "loaded": exp_index + 1,
                 },
             },
@@ -1073,9 +1070,6 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         assert on_set_is_hydrated_update == StateUpdate(
             delta={
                 state.get_name(): {
-                    # These computed vars _shouldn't_ be here, because they didn't change
-                    arg_name: exp_val,
-                    f"comp_{arg_name}": exp_val,
                     "is_hydrated": True,
                 },
             },
@@ -1097,9 +1091,6 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         assert update == StateUpdate(
             delta={
                 state.get_name(): {
-                    # These computed vars _shouldn't_ be here, because they didn't change
-                    f"comp_{arg_name}": exp_val,
-                    arg_name: exp_val,
                     "counter": exp_index + 1,
                 }
             },

+ 51 - 1
tests/test_state.py

@@ -649,7 +649,12 @@ def test_set_dirty_var(test_state):
     assert test_state.dirty_vars == set()
 
 
-def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_state):
+def test_set_dirty_substate(
+    test_state: TestState,
+    child_state: ChildState,
+    child_state2: ChildState2,
+    grandchild_state: GrandchildState,
+):
     """Test changing substate vars marks the value as dirty.
 
     Args:
@@ -3077,6 +3082,51 @@ def test_potentially_dirty_substates():
     assert C1._potentially_dirty_substates() == set()
 
 
+def test_router_var_dep() -> None:
+    """Test that router var dependencies are correctly tracked."""
+
+    class RouterVarParentState(State):
+        """A parent state for testing router var dependency."""
+
+        pass
+
+    class RouterVarDepState(RouterVarParentState):
+        """A state with a router var dependency."""
+
+        @rx.var(cache=True)
+        def foo(self) -> str:
+            return self.router.page.params.get("foo", "")
+
+    foo = RouterVarDepState.computed_vars["foo"]
+    State._init_var_dependency_dicts()
+
+    assert foo._deps(objclass=RouterVarDepState) == {"router"}
+    assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
+    assert RouterVarParentState._substate_var_dependencies == {
+        "router": {RouterVarDepState.get_name()}
+    }
+    assert RouterVarDepState._computed_var_dependencies == {
+        "router": {"foo"},
+    }
+
+    rx_state = State()
+    parent_state = RouterVarParentState()
+    state = RouterVarDepState()
+
+    # link states
+    rx_state.substates = {RouterVarParentState.get_name(): parent_state}
+    parent_state.parent_state = rx_state
+    state.parent_state = parent_state
+    parent_state.substates = {RouterVarDepState.get_name(): state}
+
+    assert state.dirty_vars == set()
+
+    # Reassign router var
+    state.router = state.router
+    assert state.dirty_vars == {"foo", "router"}
+    assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
+
+
 @pytest.mark.asyncio
 async def test_setvar(mock_app: rx.App, token: str):
     """Test that setvar works correctly.