Browse Source

[ENG-4083] Track internal changes in dataclass instances (#4558)

* [ENG-4083] Track internal changes in dataclass instances

Create a dynamic subclass of MutableProxy with `__dataclass_fields__` set
according to the dataclass being wrapped.

* support dataclasses.asdict on MutableProxy instances
Masen Furer 4 months ago
parent
commit
41cb2d8cff
2 changed files with 98 additions and 14 deletions
  1. 65 2
      reflex/state.py
  2. 33 12
      tests/units/test_state.py

+ 65 - 2
reflex/state.py

@@ -3649,6 +3649,9 @@ def get_state_manager() -> StateManager:
 class MutableProxy(wrapt.ObjectProxy):
     """A proxy for a mutable object that tracks changes."""
 
+    # Hint for finding the base class of the proxy.
+    __base_proxy__ = "MutableProxy"
+
     # Methods on wrapped objects which should mark the state as dirty.
     __mark_dirty_attrs__ = {
         "add",
@@ -3691,6 +3694,39 @@ class MutableProxy(wrapt.ObjectProxy):
         BaseModelV1,
     )
 
+    # Dynamically generated classes for tracking dataclass mutations.
+    __dataclass_proxies__: Dict[type, type] = {}
+
+    def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
+        """Create a proxy instance for a mutable object that tracks changes.
+
+        Args:
+            wrapped: The object to proxy.
+            *args: Other args passed to MutableProxy (ignored).
+            **kwargs: Other kwargs passed to MutableProxy (ignored).
+
+        Returns:
+            The proxy instance.
+        """
+        if dataclasses.is_dataclass(wrapped):
+            wrapped_cls = type(wrapped)
+            wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
+            # Find the associated class
+            if wrapper_cls_name not in cls.__dataclass_proxies__:
+                # Create a new class that has the __dataclass_fields__ defined
+                cls.__dataclass_proxies__[wrapper_cls_name] = type(
+                    wrapper_cls_name,
+                    (cls,),
+                    {
+                        dataclasses._FIELDS: getattr(  # pyright: ignore [reportGeneralTypeIssues]
+                            wrapped_cls,
+                            dataclasses._FIELDS,  # pyright: ignore [reportGeneralTypeIssues]
+                        ),
+                    },
+                )
+            cls = cls.__dataclass_proxies__[wrapper_cls_name]
+        return super().__new__(cls)
+
     def __init__(self, wrapped: Any, state: BaseState, field_name: str):
         """Create a proxy for a mutable object that tracks changes.
 
@@ -3747,7 +3783,27 @@ class MutableProxy(wrapt.ObjectProxy):
         Returns:
             Whether the value is of a mutable type.
         """
-        return isinstance(value, cls.__mutable_types__)
+        return isinstance(value, cls.__mutable_types__) or (
+            dataclasses.is_dataclass(value) and not isinstance(value, Var)
+        )
+
+    @staticmethod
+    def _is_called_from_dataclasses_internal() -> bool:
+        """Check if the current function is called from dataclasses helper.
+
+        Returns:
+            Whether the current function is called from dataclasses internal code.
+        """
+        # Walk up the stack a bit to see if we are called from dataclasses
+        # internal code, for example `asdict` or `astuple`.
+        frame = inspect.currentframe()
+        for _ in range(5):
+            # Why not `inspect.stack()` -- this is much faster!
+            if not (frame := frame and frame.f_back):
+                break
+            if inspect.getfile(frame) == dataclasses.__file__:
+                return True
+        return False
 
     def _wrap_recursive(self, value: Any) -> Any:
         """Wrap a value recursively if it is mutable.
@@ -3758,9 +3814,13 @@ class MutableProxy(wrapt.ObjectProxy):
         Returns:
             The wrapped value.
         """
+        # When called from dataclasses internal code, return the unwrapped value
+        if self._is_called_from_dataclasses_internal():
+            return value
         # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
         if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
-            return type(self)(
+            base_cls = globals()[self.__base_proxy__]
+            return base_cls(
                 wrapped=value,
                 state=self._self_state,
                 field_name=self._self_field_name,
@@ -3968,6 +4028,9 @@ class ImmutableMutableProxy(MutableProxy):
     to modify the wrapped object when the StateProxy is immutable.
     """
 
+    # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
+    __base_proxy__ = "ImmutableMutableProxy"
+
     def _mark_dirty(
         self,
         wrapped=None,

+ 33 - 12
tests/units/test_state.py

@@ -1936,6 +1936,14 @@ def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
     return mock_app_simple
 
 
+@dataclasses.dataclass
+class ModelDC:
+    """A dataclass."""
+
+    foo: str = "bar"
+    ls: list[dict] = dataclasses.field(default_factory=list)
+
+
 @pytest.mark.asyncio
 async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     """Test that the state proxy works.
@@ -2038,6 +2046,7 @@ class BackgroundTaskState(BaseState):
 
     order: List[str] = []
     dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
+    dc: ModelDC = ModelDC()
 
     def __init__(self, **kwargs):  # noqa: D107
         super().__init__(**kwargs)
@@ -2063,10 +2072,18 @@ class BackgroundTaskState(BaseState):
         with pytest.raises(ImmutableStateError):
             self.order.append("bad idea")
 
+        with pytest.raises(ImmutableStateError):
+            # Cannot manipulate dataclass attributes.
+            self.dc.foo = "baz"
+
         with pytest.raises(ImmutableStateError):
             # Even nested access to mutables raises an exception.
             self.dict_list["foo"].append(42)
 
+        with pytest.raises(ImmutableStateError):
+            # Cannot modify dataclass list attribute.
+            self.dc.ls.append({"foo": "bar"})
+
         with pytest.raises(ImmutableStateError):
             # Direct calling another handler that modifies state raises an exception.
             self.other()
@@ -3582,13 +3599,6 @@ class ModelV2(BaseModelV2):
     foo: str = "bar"
 
 
-@dataclasses.dataclass
-class ModelDC:
-    """A dataclass."""
-
-    foo: str = "bar"
-
-
 class PydanticState(rx.State):
     """A state with pydantic BaseModel vars."""
 
@@ -3610,11 +3620,22 @@ def test_mutable_models():
     assert state.dirty_vars == {"v2"}
     state.dirty_vars.clear()
 
-    # Not yet supported ENG-4083
-    # assert isinstance(state.dc, MutableProxy) #noqa: ERA001
-    # state.dc.foo = "baz" #noqa: ERA001
-    # assert state.dirty_vars == {"dc"} #noqa: ERA001
-    # state.dirty_vars.clear() #noqa: ERA001
+    assert isinstance(state.dc, MutableProxy)
+    state.dc.foo = "baz"
+    assert state.dirty_vars == {"dc"}
+    state.dirty_vars.clear()
+    assert state.dirty_vars == set()
+    state.dc.ls.append({"hi": "reflex"})
+    assert state.dirty_vars == {"dc"}
+    state.dirty_vars.clear()
+    assert state.dirty_vars == set()
+    assert dataclasses.asdict(state.dc) == {"foo": "baz", "ls": [{"hi": "reflex"}]}
+    assert dataclasses.astuple(state.dc) == ("baz", [{"hi": "reflex"}])
+    # creating a new instance shouldn't mark the state dirty
+    assert dataclasses.replace(state.dc, foo="quuc") == ModelDC(
+        foo="quuc", ls=[{"hi": "reflex"}]
+    )
+    assert state.dirty_vars == set()
 
 
 def test_get_value():