Ver código fonte

[REF-2045] Implement __reduce_ex__ for MutableProxy (#2688)

* test_state: augment modify_state test for writing MutableProxy

If the object contains a MutableProxy inside of it, then we get a pickling
error.

* Implement __reduce_ex__ for MutableProxy

Pass through `__reduce_ex__` onto the wrapped instance to strip it off when
cloudpickling to redis.

* base: get_value actually works with a str key

Unless the key isn't a field on the model, then it falls back to the previous
behavior of just returning the given key as is... why does it do this? I don't
know.
Masen Furer 1 ano atrás
pai
commit
953495775d
3 arquivos alterados com 24 adições e 1 exclusões
  1. 4 0
      reflex/base.py
  2. 15 0
      reflex/state.py
  3. 5 1
      tests/test_state.py

+ 4 - 0
reflex/base.py

@@ -115,6 +115,10 @@ class Base(pydantic.BaseModel):
         Returns:
             The value of the field.
         """
+        if isinstance(key, str) and key in self.__fields__:
+            # Seems like this function signature was wrong all along?
+            # If the user wants a field that we know of, get it and pass it off to _get_value
+            key = getattr(self, key)
         return self._get_value(
             key,
             to_dict=True,

+ 15 - 0
reflex/state.py

@@ -2363,6 +2363,21 @@ class MutableProxy(wrapt.ObjectProxy):
         """
         return copy.deepcopy(self.__wrapped__, memo=memo)
 
+    def __reduce_ex__(self, protocol_version):
+        """Get the state for redis serialization.
+
+        This method is called by cloudpickle to serialize the object.
+
+        It explicitly serializes the wrapped object, stripping off the mutable proxy.
+
+        Args:
+            protocol_version: The protocol version.
+
+        Returns:
+            Tuple of (wrapped class, empty args, class __getstate__)
+        """
+        return self.__wrapped__.__reduce_ex__(protocol_version)
+
 
 @serializer
 def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType:

+ 5 - 1
tests/test_state.py

@@ -1457,12 +1457,16 @@ async def test_state_manager_modify_state(
         token: A token.
         substate_token: A token + substate name for looking up in state manager.
     """
-    async with state_manager.modify_state(substate_token):
+    async with state_manager.modify_state(substate_token) as state:
         if isinstance(state_manager, StateManagerRedis):
             assert await state_manager.redis.get(f"{token}_lock")
         elif isinstance(state_manager, StateManagerMemory):
             assert token in state_manager._states_locks
             assert state_manager._states_locks[token].locked()
+        # Should be able to write proxy objects inside mutables
+        complex_1 = state.complex[1]
+        assert isinstance(complex_1, MutableProxy)
+        state.complex[3] = complex_1
     # lock should be dropped after exiting the context
     if isinstance(state_manager, StateManagerRedis):
         assert (await state_manager.redis.get(f"{token}_lock")) is None