Bläddra i källkod

MutableProxy wraps values yielded by __iter__ (#1876)

Masen Furer 1 år sedan
förälder
incheckning
5ca7f29853
2 ändrade filer med 117 tillägg och 24 borttagningar
  1. 77 22
      reflex/state.py
  2. 40 2
      tests/test_state.py

+ 77 - 22
reflex/state.py

@@ -1647,6 +1647,13 @@ class MutableProxy(wrapt.ObjectProxy):
             "update",
         ]
     )
+    # Methods on wrapped objects might return mutable objects that should be tracked.
+    __wrap_mutable_attrs__ = set(
+        [
+            "get",
+            "setdefault",
+        ]
+    )
 
     __mutable_types__ = (list, dict, set, Base)
 
@@ -1663,7 +1670,13 @@ class MutableProxy(wrapt.ObjectProxy):
         self._self_state = state
         self._self_field_name = field_name
 
-    def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
+    def _mark_dirty(
+        self,
+        wrapped=None,
+        instance=None,
+        args=tuple(),
+        kwargs=None,
+    ) -> Any:
         """Mark the state as dirty, then call a wrapped function.
 
         Intended for use with `FunctionWrapper` from the `wrapt` library.
@@ -1673,11 +1686,47 @@ class MutableProxy(wrapt.ObjectProxy):
             instance: The instance of the wrapped function.
             args: The args for the wrapped function.
             kwargs: The kwargs for the wrapped function.
+
+        Returns:
+            The result of the wrapped function.
         """
         self._self_state.dirty_vars.add(self._self_field_name)
         self._self_state._mark_dirty()
         if wrapped is not None:
-            wrapped(*args, **(kwargs or {}))
+            return wrapped(*args, **(kwargs or {}))
+
+    def _wrap_recursive(self, value: Any) -> Any:
+        """Wrap a value recursively if it is mutable.
+
+        Args:
+            value: The value to wrap.
+
+        Returns:
+            The wrapped value.
+        """
+        if isinstance(value, self.__mutable_types__):
+            return type(self)(
+                wrapped=value,
+                state=self._self_state,
+                field_name=self._self_field_name,
+            )
+        return value
+
+    def _wrap_recursive_decorator(self, wrapped, instance, args, kwargs) -> Any:
+        """Wrap a function that returns a possibly mutable value.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Returns:
+            The result of the wrapped function (possibly wrapped in a MutableProxy).
+        """
+        return self._wrap_recursive(wrapped(*args, **kwargs))
 
     def __getattribute__(self, __name: str) -> Any:
         """Get the attribute on the proxied object and return a proxy if mutable.
@@ -1690,24 +1739,26 @@ class MutableProxy(wrapt.ObjectProxy):
         """
         value = super().__getattribute__(__name)
 
-        if callable(value) and __name in super().__getattribute__(
-            "__mark_dirty_attrs__"
-        ):
-            # Wrap special callables, like "append", which should mark state dirty.
-            return wrapt.FunctionWrapper(
-                value,
-                super().__getattribute__("_mark_dirty"),
-            )
+        if callable(value):
+            if __name in super().__getattribute__("__mark_dirty_attrs__"):
+                # Wrap special callables, like "append", which should mark state dirty.
+                value = wrapt.FunctionWrapper(
+                    value,
+                    super().__getattribute__("_mark_dirty"),
+                )
+
+            if __name in super().__getattribute__("__wrap_mutable_attrs__"):
+                # Wrap methods that may return mutable objects tied to the state.
+                value = wrapt.FunctionWrapper(
+                    value,
+                    super().__getattribute__("_wrap_recursive_decorator"),
+                )
 
         if isinstance(
             value, super().__getattribute__("__mutable_types__")
         ) and __name not in ("__wrapped__", "_self_state"):
             # Recursively wrap mutable attribute values retrieved through this proxy.
-            return type(self)(
-                wrapped=value,
-                state=self._self_state,
-                field_name=self._self_field_name,
-            )
+            return self._wrap_recursive(value)
 
         return value
 
@@ -1721,14 +1772,18 @@ class MutableProxy(wrapt.ObjectProxy):
             The item value.
         """
         value = super().__getitem__(key)
-        if isinstance(value, self.__mutable_types__):
+        # Recursively wrap mutable items retrieved through this proxy.
+        return self._wrap_recursive(value)
+
+    def __iter__(self) -> Any:
+        """Iterate over the proxied object and return a proxy if mutable.
+
+        Yields:
+            Each item value (possibly wrapped in MutableProxy).
+        """
+        for value in super().__iter__():
             # Recursively wrap mutable items retrieved through this proxy.
-            return type(self)(
-                wrapped=value,
-                state=self._self_state,
-                field_name=self._self_field_name,
-            )
-        return value
+            yield self._wrap_recursive(value)
 
     def __delattr__(self, name):
         """Delete the attribute on the proxied object and mark state dirty.

+ 40 - 2
tests/test_state.py

@@ -1858,6 +1858,15 @@ def test_mutable_list(mutable_state):
     assert_array_dirty()
     assert isinstance(mutable_state.array[0], MutableProxy)
 
+    # Test proxy returned from __iter__
+    mutable_state.array = [{}]
+    assert_array_dirty()
+    assert isinstance(mutable_state.array[0], MutableProxy)
+    for item in mutable_state.array:
+        assert isinstance(item, MutableProxy)
+        item["foo"] = "bar"
+        assert_array_dirty()
+
 
 def test_mutable_dict(mutable_state):
     """Test that mutable dicts are tracked correctly.
@@ -1875,9 +1884,13 @@ def test_mutable_dict(mutable_state):
     # Test all dict operations
     mutable_state.hashmap.update({"new_key": 43})
     assert_hashmap_dirty()
-    mutable_state.hashmap.setdefault("another_key", 66)
+    assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value"
+    assert_hashmap_dirty()
+    assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67
     assert_hashmap_dirty()
-    mutable_state.hashmap.pop("new_key")
+    assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67
+    assert_hashmap_dirty()
+    assert mutable_state.hashmap.pop("new_key") == 43
     assert_hashmap_dirty()
     mutable_state.hashmap.popitem()
     assert_hashmap_dirty()
@@ -1905,6 +1918,31 @@ def test_mutable_dict(mutable_state):
     mutable_state.hashmap["dict"]["dict"]["key"] = 43
     assert_hashmap_dirty()
 
+    # Test proxy returned from `setdefault` and `get`
+    mutable_value = mutable_state.hashmap.setdefault("setdefault_mutable_key", [])
+    assert_hashmap_dirty()
+    assert mutable_value == []
+    assert isinstance(mutable_value, MutableProxy)
+    mutable_value.append("foo")
+    assert_hashmap_dirty()
+    mutable_value_other_ref = mutable_state.hashmap.get("setdefault_mutable_key")
+    assert isinstance(mutable_value_other_ref, MutableProxy)
+    assert mutable_value is not mutable_value_other_ref
+    assert mutable_value == mutable_value_other_ref
+    assert not mutable_state.dirty_vars
+    mutable_value_other_ref.append("bar")
+    assert_hashmap_dirty()
+
+    # `pop` should NOT return a proxy, because the returned value is no longer in the dict
+    mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key")
+    assert not isinstance(mutable_value_third_ref, MutableProxy)
+    assert_hashmap_dirty()
+    mutable_value_third_ref.append("baz")
+    assert not mutable_state.dirty_vars
+    # Unfortunately previous refs still will mark the state dirty... nothing doing about that
+    assert mutable_value.pop()
+    assert_hashmap_dirty()
+
 
 def test_mutable_set(mutable_state):
     """Test that mutable sets are tracked correctly.