Ver Fonte

raise StateSerializationError if the state cannot be serialized (#4453)

* raise StateSerializationError if the state cannot be serialized

* fix test
benedikt-bartscher há 5 meses atrás
pai
commit
2ee201b520
3 ficheiros alterados com 22 adições e 3 exclusões
  1. 10 0
      reflex/state.py
  2. 4 0
      reflex/utils/exceptions.py
  3. 8 3
      tests/units/test_state.py

+ 10 - 0
reflex/state.py

@@ -97,6 +97,7 @@ from reflex.utils.exceptions import (
     ReflexRuntimeError,
     ReflexRuntimeError,
     SetUndefinedStateVarError,
     SetUndefinedStateVarError,
     StateSchemaMismatchError,
     StateSchemaMismatchError,
+    StateSerializationError,
     StateTooLargeError,
     StateTooLargeError,
 )
 )
 from reflex.utils.exec import is_testing_env
 from reflex.utils.exec import is_testing_env
@@ -2193,8 +2194,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         Returns:
         Returns:
             The serialized state.
             The serialized state.
+
+        Raises:
+            StateSerializationError: If the state cannot be serialized.
         """
         """
         payload = b""
         payload = b""
+        error = ""
         try:
         try:
             payload = pickle.dumps((self._to_schema(), self))
             payload = pickle.dumps((self._to_schema(), self))
         except HANDLED_PICKLE_ERRORS as og_pickle_error:
         except HANDLED_PICKLE_ERRORS as og_pickle_error:
@@ -2214,8 +2219,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             except HANDLED_PICKLE_ERRORS as ex:
             except HANDLED_PICKLE_ERRORS as ex:
                 error += f"Dill was also unable to pickle the state: {ex}"
                 error += f"Dill was also unable to pickle the state: {ex}"
             console.warn(error)
             console.warn(error)
+
         if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
         if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
             self._check_state_size(len(payload))
             self._check_state_size(len(payload))
+
+        if not payload:
+            raise StateSerializationError(error)
+
         return payload
         return payload
 
 
     @classmethod
     @classmethod

+ 4 - 0
reflex/utils/exceptions.py

@@ -155,6 +155,10 @@ class StateTooLargeError(ReflexError):
     """Raised when the state is too large to be serialized."""
     """Raised when the state is too large to be serialized."""
 
 
 
 
+class StateSerializationError(ReflexError):
+    """Raised when the state cannot be serialized."""
+
+
 class SystemPackageMissingError(ReflexError):
 class SystemPackageMissingError(ReflexError):
     """Raised when a system package is missing."""
     """Raised when a system package is missing."""
 
 

+ 8 - 3
tests/units/test_state.py

@@ -55,7 +55,11 @@ from reflex.state import (
 )
 )
 from reflex.testing import chdir
 from reflex.testing import chdir
 from reflex.utils import format, prerequisites, types
 from reflex.utils import format, prerequisites, types
-from reflex.utils.exceptions import ReflexRuntimeError, SetUndefinedStateVarError
+from reflex.utils.exceptions import (
+    ReflexRuntimeError,
+    SetUndefinedStateVarError,
+    StateSerializationError,
+)
 from reflex.utils.format import json_dumps
 from reflex.utils.format import json_dumps
 from reflex.vars.base import Var, computed_var
 from reflex.vars.base import Var, computed_var
 from tests.units.states.mutation import MutableSQLAModel, MutableTestState
 from tests.units.states.mutation import MutableSQLAModel, MutableTestState
@@ -3433,8 +3437,9 @@ def test_fallback_pickle():
     # Some object, like generator, are still unpicklable with dill.
     # Some object, like generator, are still unpicklable with dill.
     state3 = DillState(_reflex_internal_init=True)  # type: ignore
     state3 = DillState(_reflex_internal_init=True)  # type: ignore
     state3._g = (i for i in range(10))
     state3._g = (i for i in range(10))
-    pk3 = state3._serialize()
-    assert len(pk3) == 0
+
+    with pytest.raises(StateSerializationError):
+        _ = state3._serialize()
 
 
 
 
 def test_typed_state() -> None:
 def test_typed_state() -> None: