|
@@ -3649,6 +3649,9 @@ def get_state_manager() -> StateManager:
|
|
|
class MutableProxy(wrapt.ObjectProxy):
|
|
|
"""A proxy for a mutable object that tracks changes."""
|
|
|
|
|
|
+ # Hint for finding the base class of the proxy.
|
|
|
+ __base_proxy__ = "MutableProxy"
|
|
|
+
|
|
|
# Methods on wrapped objects which should mark the state as dirty.
|
|
|
__mark_dirty_attrs__ = {
|
|
|
"add",
|
|
@@ -3691,6 +3694,39 @@ class MutableProxy(wrapt.ObjectProxy):
|
|
|
BaseModelV1,
|
|
|
)
|
|
|
|
|
|
+ # Dynamically generated classes for tracking dataclass mutations.
|
|
|
+ __dataclass_proxies__: Dict[type, type] = {}
|
|
|
+
|
|
|
+ def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
|
|
|
+ """Create a proxy instance for a mutable object that tracks changes.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ wrapped: The object to proxy.
|
|
|
+ *args: Other args passed to MutableProxy (ignored).
|
|
|
+ **kwargs: Other kwargs passed to MutableProxy (ignored).
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The proxy instance.
|
|
|
+ """
|
|
|
+ if dataclasses.is_dataclass(wrapped):
|
|
|
+ wrapped_cls = type(wrapped)
|
|
|
+ wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
|
|
|
+ # Find the associated class
|
|
|
+ if wrapper_cls_name not in cls.__dataclass_proxies__:
|
|
|
+ # Create a new class that has the __dataclass_fields__ defined
|
|
|
+ cls.__dataclass_proxies__[wrapper_cls_name] = type(
|
|
|
+ wrapper_cls_name,
|
|
|
+ (cls,),
|
|
|
+ {
|
|
|
+ dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues]
|
|
|
+ wrapped_cls,
|
|
|
+ dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues]
|
|
|
+ ),
|
|
|
+ },
|
|
|
+ )
|
|
|
+ cls = cls.__dataclass_proxies__[wrapper_cls_name]
|
|
|
+ return super().__new__(cls)
|
|
|
+
|
|
|
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
|
|
"""Create a proxy for a mutable object that tracks changes.
|
|
|
|
|
@@ -3747,7 +3783,27 @@ class MutableProxy(wrapt.ObjectProxy):
|
|
|
Returns:
|
|
|
Whether the value is of a mutable type.
|
|
|
"""
|
|
|
- return isinstance(value, cls.__mutable_types__)
|
|
|
+ return isinstance(value, cls.__mutable_types__) or (
|
|
|
+ dataclasses.is_dataclass(value) and not isinstance(value, Var)
|
|
|
+ )
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _is_called_from_dataclasses_internal() -> bool:
|
|
|
+ """Check if the current function is called from dataclasses helper.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Whether the current function is called from dataclasses internal code.
|
|
|
+ """
|
|
|
+ # Walk up the stack a bit to see if we are called from dataclasses
|
|
|
+ # internal code, for example `asdict` or `astuple`.
|
|
|
+ frame = inspect.currentframe()
|
|
|
+ for _ in range(5):
|
|
|
+ # Why not `inspect.stack()` -- this is much faster!
|
|
|
+ if not (frame := frame and frame.f_back):
|
|
|
+ break
|
|
|
+ if inspect.getfile(frame) == dataclasses.__file__:
|
|
|
+ return True
|
|
|
+ return False
|
|
|
|
|
|
def _wrap_recursive(self, value: Any) -> Any:
|
|
|
"""Wrap a value recursively if it is mutable.
|
|
@@ -3758,9 +3814,13 @@ class MutableProxy(wrapt.ObjectProxy):
|
|
|
Returns:
|
|
|
The wrapped value.
|
|
|
"""
|
|
|
+ # When called from dataclasses internal code, return the unwrapped value
|
|
|
+ if self._is_called_from_dataclasses_internal():
|
|
|
+ return value
|
|
|
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
|
|
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
|
|
|
- return type(self)(
|
|
|
+ base_cls = globals()[self.__base_proxy__]
|
|
|
+ return base_cls(
|
|
|
wrapped=value,
|
|
|
state=self._self_state,
|
|
|
field_name=self._self_field_name,
|
|
@@ -3968,6 +4028,9 @@ class ImmutableMutableProxy(MutableProxy):
|
|
|
to modify the wrapped object when the StateProxy is immutable.
|
|
|
"""
|
|
|
|
|
|
+ # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
|
|
|
+ __base_proxy__ = "ImmutableMutableProxy"
|
|
|
+
|
|
|
def _mark_dirty(
|
|
|
self,
|
|
|
wrapped=None,
|