|
@@ -691,6 +691,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
parent_state.get_parent_state(),
|
|
|
)
|
|
|
|
|
|
+ # Reset cached schema value
|
|
|
+ cls._to_schema.cache_clear()
|
|
|
+
|
|
|
@classmethod
|
|
|
def _check_overridden_methods(cls):
|
|
|
"""Check for shadow methods and raise error if any.
|
|
@@ -1945,20 +1948,58 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
The state dict for serialization.
|
|
|
"""
|
|
|
state = super().__getstate__()
|
|
|
- # Never serialize parent_state or substates
|
|
|
state["__dict__"] = state["__dict__"].copy()
|
|
|
+ if state["__dict__"].get("parent_state") is not None:
|
|
|
+ # Do not serialize router data in substates (only the root state).
|
|
|
+ state["__dict__"].pop("router", None)
|
|
|
+ state["__dict__"].pop("router_data", None)
|
|
|
+ # Never serialize parent_state or substates.
|
|
|
state["__dict__"]["parent_state"] = None
|
|
|
state["__dict__"]["substates"] = {}
|
|
|
state["__dict__"].pop("_was_touched", None)
|
|
|
+ # Remove all inherited vars.
|
|
|
+ for inherited_var_name in self.inherited_vars:
|
|
|
+ state["__dict__"].pop(inherited_var_name, None)
|
|
|
return state
|
|
|
|
|
|
+ @classmethod
|
|
|
+ @functools.lru_cache()
|
|
|
+ def _to_schema(cls) -> str:
|
|
|
+ """Convert a state to a schema.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The hash of the schema.
|
|
|
+ """
|
|
|
+
|
|
|
+ def _field_tuple(
|
|
|
+ field_name: str,
|
|
|
+ ) -> Tuple[str, str, Any, Union[bool, None], Any]:
|
|
|
+ model_field = cls.__fields__[field_name]
|
|
|
+ return (
|
|
|
+ field_name,
|
|
|
+ model_field.name,
|
|
|
+ _serialize_type(model_field.type_),
|
|
|
+ (
|
|
|
+ model_field.required
|
|
|
+ if isinstance(model_field.required, bool)
|
|
|
+ else None
|
|
|
+ ),
|
|
|
+ (model_field.default if is_serializable(model_field.default) else None),
|
|
|
+ )
|
|
|
+
|
|
|
+ return md5(
|
|
|
+ pickle.dumps(
|
|
|
+ list(sorted(_field_tuple(field_name) for field_name in cls.base_vars))
|
|
|
+ )
|
|
|
+ ).hexdigest()
|
|
|
+
|
|
|
def _serialize(self) -> bytes:
|
|
|
"""Serialize the state for redis.
|
|
|
|
|
|
Returns:
|
|
|
The serialized state.
|
|
|
"""
|
|
|
- return pickle.dumps((state_to_schema(self), self))
|
|
|
+ return pickle.dumps((self._to_schema(), self))
|
|
|
|
|
|
@classmethod
|
|
|
def _deserialize(
|
|
@@ -1985,7 +2026,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
(substate_schema, state) = pickle.load(fp)
|
|
|
else:
|
|
|
raise ValueError("Only one of `data` or `fp` must be provided")
|
|
|
- if substate_schema != state_to_schema(state):
|
|
|
+ if substate_schema != state._to_schema():
|
|
|
raise StateSchemaMismatchError()
|
|
|
return state
|
|
|
|
|
@@ -2620,35 +2661,6 @@ def is_serializable(value: Any) -> bool:
|
|
|
return False
|
|
|
|
|
|
|
|
|
-def state_to_schema(
|
|
|
- state: BaseState,
|
|
|
-) -> List[Tuple[str, str, Any, Union[bool, None], Any]]:
|
|
|
- """Convert a state to a schema.
|
|
|
-
|
|
|
- Args:
|
|
|
- state: The state to convert to a schema.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The schema.
|
|
|
- """
|
|
|
- return list(
|
|
|
- sorted(
|
|
|
- (
|
|
|
- field_name,
|
|
|
- model_field.name,
|
|
|
- _serialize_type(model_field.type_),
|
|
|
- (
|
|
|
- model_field.required
|
|
|
- if isinstance(model_field.required, bool)
|
|
|
- else None
|
|
|
- ),
|
|
|
- (model_field.default if is_serializable(model_field.default) else None),
|
|
|
- )
|
|
|
- for field_name, model_field in state.__fields__.items()
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
def reset_disk_state_manager():
|
|
|
"""Reset the disk state manager."""
|
|
|
states_directory = prerequisites.get_web_dir() / constants.Dirs.STATES
|