瀏覽代碼

state: implement __copy__ and __deepcopy__ for MutableProxy (#1845)

Masen Furer 1 年之前
父節點
當前提交
1bfb579b20
共有 2 個文件被更改,包括 70 次插入0 次删除
  1. 19 0
      reflex/state.py
  2. 51 0
      tests/test_state.py

+ 19 - 0
reflex/state.py

@@ -1313,3 +1313,22 @@ class MutableProxy(wrapt.ObjectProxy):
             super().__setattr__(name, value)
             return
         self._mark_dirty(super().__setattr__, args=(name, value))
+
+    def __copy__(self) -> Any:
+        """Return a copy of the proxy.
+
+        Returns:
+            A copy of the wrapped object, unconnected to the proxy.
+        """
+        return copy.copy(self.__wrapped__)
+
+    def __deepcopy__(self, memo=None) -> Any:
+        """Return a deepcopy of the proxy.
+
+        Args:
+            memo: The memo dict to use for the deepcopy.
+
+        Returns:
+            A deepcopy of the wrapped object, unconnected to the proxy.
+        """
+        return copy.deepcopy(self.__wrapped__, memo=memo)

+ 51 - 0
tests/test_state.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import datetime
 import functools
 import sys
@@ -1585,6 +1586,56 @@ def test_mutable_backend(mutable_state):
     assert_custom_dirty()
 
 
+@pytest.mark.parametrize(
+    ("copy_func",),
+    [
+        (copy.copy,),
+        (copy.deepcopy,),
+    ],
+)
+def test_mutable_copy(mutable_state, copy_func):
+    """Test that mutable types are copied correctly.
+
+    Args:
+        mutable_state: A test state.
+        copy_func: A copy function.
+    """
+    ms_copy = copy_func(mutable_state)
+    assert ms_copy is not mutable_state
+    for attr in ("array", "hashmap", "test_set", "custom"):
+        assert getattr(ms_copy, attr) == getattr(mutable_state, attr)
+        assert getattr(ms_copy, attr) is not getattr(mutable_state, attr)
+    ms_copy.custom.array.append(42)
+    assert "custom" in ms_copy.dirty_vars
+    if copy_func is copy.copy:
+        assert "custom" in mutable_state.dirty_vars
+    else:
+        assert not mutable_state.dirty_vars
+
+
+@pytest.mark.parametrize(
+    ("copy_func",),
+    [
+        (copy.copy,),
+        (copy.deepcopy,),
+    ],
+)
+def test_mutable_copy_vars(mutable_state, copy_func):
+    """Test that mutable types are copied correctly.
+
+    Args:
+        mutable_state: A test state.
+        copy_func: A copy function.
+    """
+    for attr in ("array", "hashmap", "test_set", "custom"):
+        var_orig = getattr(mutable_state, attr)
+        var_copy = copy_func(var_orig)
+        assert var_orig is not var_copy
+        assert var_orig == var_copy
+        # copied vars should never be proxies, as they by definition are no longer attached to the state.
+        assert not isinstance(var_copy, MutableProxy)
+
+
 def test_duplicate_substate_class(duplicate_substate):
     with pytest.raises(ValueError):
         duplicate_substate()