Forráskód Böngészése

Reduce pickle size (#4063)

* Only serialize base vars
* Never serialize router/router_data in substates
* Hash the schema to reduce serialized size
* lru_cache the schema to avoid recomputing it
Masen Furer 7 hónapja
szülő
commit
edd17208c0
2 módosított fájl, 47 hozzáadás és 35 törlés
  1. 44 32
      reflex/state.py
  2. 3 3
      tests/integration/test_dynamic_routes.py

+ 44 - 32
reflex/state.py

@@ -691,6 +691,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                     parent_state.get_parent_state(),
                 )
 
+        # Reset cached schema value
+        cls._to_schema.cache_clear()
+
     @classmethod
     def _check_overridden_methods(cls):
         """Check for shadow methods and raise error if any.
@@ -1945,20 +1948,58 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             The state dict for serialization.
         """
         state = super().__getstate__()
-        # Never serialize parent_state or substates
         state["__dict__"] = state["__dict__"].copy()
+        if state["__dict__"].get("parent_state") is not None:
+            # Do not serialize router data in substates (only the root state).
+            state["__dict__"].pop("router", None)
+            state["__dict__"].pop("router_data", None)
+        # Never serialize parent_state or substates.
         state["__dict__"]["parent_state"] = None
         state["__dict__"]["substates"] = {}
         state["__dict__"].pop("_was_touched", None)
+        # Remove all inherited vars.
+        for inherited_var_name in self.inherited_vars:
+            state["__dict__"].pop(inherited_var_name, None)
         return state
 
+    @classmethod
+    @functools.lru_cache()
+    def _to_schema(cls) -> str:
+        """Compute a hash identifying the schema of this state class.
+
+        Only the fields named in ``cls.base_vars`` contribute to the schema.
+        The result is memoized by ``functools.lru_cache`` (keyed on ``cls``).
+
+        Returns:
+            The MD5 hex digest of the pickled, sorted list of field tuples.
+        """
+
+        def _field_tuple(
+            field_name: str,
+        ) -> Tuple[str, str, Any, Union[bool, None], Any]:
+            # Describe one field as (key, name, serialized type, required, default).
+            model_field = cls.__fields__[field_name]
+            return (
+                field_name,
+                model_field.name,
+                _serialize_type(model_field.type_),
+                (
+                    # A non-bool `required` value is recorded as None.
+                    model_field.required
+                    if isinstance(model_field.required, bool)
+                    else None
+                ),
+                # A default that fails is_serializable is recorded as None.
+                (model_field.default if is_serializable(model_field.default) else None),
+            )
+
+        # Sort the field tuples so the digest does not depend on the
+        # iteration order of base_vars.
+        return md5(
+            pickle.dumps(
+                list(sorted(_field_tuple(field_name) for field_name in cls.base_vars))
+            )
+        ).hexdigest()
+
     def _serialize(self) -> bytes:
         """Serialize the state for redis.
 
         Returns:
             The serialized state.
         """
-        return pickle.dumps((state_to_schema(self), self))
+        # Pickle the schema hash together with the state; _deserialize compares
+        # it to the current schema and raises StateSchemaMismatchError if they differ.
+        return pickle.dumps((self._to_schema(), self))
 
     @classmethod
     def _deserialize(
@@ -1985,7 +2026,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             (substate_schema, state) = pickle.load(fp)
         else:
             raise ValueError("Only one of `data` or `fp` must be provided")
-        if substate_schema != state_to_schema(state):
+        if substate_schema != state._to_schema():
             raise StateSchemaMismatchError()
         return state
 
@@ -2620,35 +2661,6 @@ def is_serializable(value: Any) -> bool:
         return False
 
 
-def state_to_schema(
-    state: BaseState,
-) -> List[Tuple[str, str, Any, Union[bool, None], Any]]:
-    """Convert a state to a schema.
-
-    Args:
-        state: The state to convert to a schema.
-
-    Returns:
-        The schema.
-    """
-    return list(
-        sorted(
-            (
-                field_name,
-                model_field.name,
-                _serialize_type(model_field.type_),
-                (
-                    model_field.required
-                    if isinstance(model_field.required, bool)
-                    else None
-                ),
-                (model_field.default if is_serializable(model_field.default) else None),
-            )
-            for field_name, model_field in state.__fields__.items()
-        )
-    )
-
-
 def reset_disk_state_manager():
     """Reset the disk state manager."""
     states_directory = prerequisites.get_web_dir() / constants.Dirs.STATES

+ 3 - 3
tests/integration/test_dynamic_routes.py

@@ -41,13 +41,13 @@ def DynamicRoute():
         return rx.fragment(
             rx.input(
                 value=DynamicState.router.session.client_token,
-                is_read_only=True,
+                read_only=True,
                 id="token",
             ),
-            rx.input(value=rx.State.page_id, is_read_only=True, id="page_id"),  # type: ignore
+            rx.input(value=rx.State.page_id, read_only=True, id="page_id"),  # type: ignore
             rx.input(
                 value=DynamicState.router.page.raw_path,
-                is_read_only=True,
+                read_only=True,
                 id="raw_path",
             ),
             rx.link("index", href="/", id="link_index"),