瀏覽代碼

[REF-1356] Track changes applied to `Base` subclass via helper method. (#2242)

Masen Furer 11 月之前
父節點
當前提交
b04e3a6ce9
共有 2 個文件被更改,包括 85 次插入1 次删除
  1. 20 1
      reflex/state.py
  2. 65 0
      tests/test_state.py

+ 20 - 1
reflex/state.py

@@ -2867,6 +2867,11 @@ class MutableProxy(wrapt.ObjectProxy):
         ]
     )
 
+    # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
+    __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
+        pydantic.BaseModel.__dict__
+    )
+
     __mutable_types__ = (list, dict, set, Base)
 
     def __init__(self, wrapped: Any, state: BaseState, field_name: str):
@@ -2916,7 +2921,10 @@ class MutableProxy(wrapt.ObjectProxy):
         Returns:
             The wrapped value.
         """
-        if isinstance(value, self.__mutable_types__):
+        # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
+        if isinstance(value, self.__mutable_types__) and not isinstance(
+            value, MutableProxy
+        ):
             return type(self)(
                 wrapped=value,
                 state=self._self_state,
@@ -2963,6 +2971,17 @@ class MutableProxy(wrapt.ObjectProxy):
                     self._wrap_recursive_decorator,
                 )
 
+            if (
+                isinstance(self.__wrapped__, Base)
+                and __name not in self.__never_wrap_base_attrs__
+                and hasattr(value, "__func__")
+            ):
+                # Wrap methods called on Base subclasses, which might do _anything_
+                return wrapt.FunctionWrapper(
+                    functools.partial(value.__func__, self),
+                    self._wrap_recursive_decorator,
+                )
+
         if isinstance(value, self.__mutable_types__) and __name not in (
             "__wrapped__",
             "_self_state",

+ 65 - 0
tests/test_state.py

@@ -2392,6 +2392,22 @@ class Custom1(Base):
 
     foo: str
 
+    def set_foo(self, val: str):
+        """Set the attribute foo.
+
+        Args:
+            val: The value to set.
+        """
+        self.foo = val
+
+    def double_foo(self) -> str:
+        """Concantenate foo with foo.
+
+        Returns:
+            foo + foo
+        """
+        return self.foo + self.foo
+
 
 class Custom2(Base):
     """A custom class with a Custom1 field."""
@@ -2399,6 +2415,14 @@ class Custom2(Base):
     c1: Optional[Custom1] = None
     c1r: Custom1
 
+    def set_c1r_foo(self, val: str):
+        """Set the foo attribute of the c1 field.
+
+        Args:
+            val: The value to set.
+        """
+        self.c1r.set_foo(val)
+
 
 class Custom3(Base):
     """A custom class with a Custom2 field."""
@@ -2436,6 +2460,47 @@ def test_state_union_optional():
     assert types.is_union(UnionState.int_float._var_type)  # type: ignore
 
 
+def test_set_base_field_via_setter():
+    """When calling a setter on a Base instance, also track changes."""
+
+    class BaseFieldSetterState(BaseState):
+        c1: Custom1 = Custom1(foo="")
+        c2: Custom2 = Custom2(c1r=Custom1(foo=""))
+
+    bfss = BaseFieldSetterState()
+    assert "c1" not in bfss.dirty_vars
+
+    # Non-mutating function, not dirty
+    bfss.c1.double_foo()
+    assert "c1" not in bfss.dirty_vars
+
+    # Mutating function, dirty
+    bfss.c1.set_foo("bar")
+    assert "c1" in bfss.dirty_vars
+    bfss.dirty_vars.clear()
+    assert "c1" not in bfss.dirty_vars
+
+    # Mutating function from Base, dirty
+    bfss.c1.set(foo="bar")
+    assert "c1" in bfss.dirty_vars
+    bfss.dirty_vars.clear()
+    assert "c1" not in bfss.dirty_vars
+
+    # Assert identity of MutableProxy
+    mp = bfss.c1
+    assert isinstance(mp, MutableProxy)
+    mp2 = mp.set()
+    assert mp is mp2
+    mp3 = bfss.c1.set()
+    assert mp is not mp3
+    # Since none of these set calls had values, the state should not be dirty
+    assert not bfss.dirty_vars
+
+    # Chained Mutating function, dirty
+    bfss.c2.set_c1r_foo("baz")
+    assert "c2" in bfss.dirty_vars
+
+
 def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
     """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.