Explorar el Código

[ENG-3867] Garden Variety Pickle (#4054)

* Use regular `pickle` module from stdlib

* Avoid recreating the rx.State tree for every `get_state`

* Remove dill dependency

* relock deps
Masen Furer hace 7 meses
padre
commit
d77b900bd7
Se han modificado 4 ficheros con 103 adiciones y 85 borrados
  1. 43 53
      poetry.lock
  2. 0 1
      pyproject.toml
  3. 56 31
      reflex/state.py
  4. 4 0
      reflex/utils/exceptions.py

+ 43 - 53
poetry.lock

@@ -516,21 +516,6 @@ files = [
     {file = "darglint-1.8.1.tar.gz", hash = "sha256:080d5106df149b199822e7ee7deb9c012b49891538f14a11be681044f0bb20da"},
 ]
 
-[[package]]
-name = "dill"
-version = "0.3.8"
-description = "serialize all of Python"
-optional = false
-python-versions = ">=3.8"
-files = [
-    {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
-    {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
-]
-
-[package.extras]
-graph = ["objgraph (>=1.7.2)"]
-profile = ["gprof2dot (>=2022.7.29)"]
-
 [[package]]
 name = "distlib"
 version = "0.3.8"
@@ -719,13 +704,13 @@ files = [
 
 [[package]]
 name = "httpcore"
-version = "1.0.5"
+version = "1.0.6"
 description = "A minimal low-level HTTP client."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"},
-    {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"},
+    {file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"},
+    {file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"},
 ]
 
 [package.dependencies]
@@ -736,7 +721,7 @@ h11 = ">=0.13,<0.15"
 asyncio = ["anyio (>=4.0,<5.0)"]
 http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
-trio = ["trio (>=0.22.0,<0.26.0)"]
+trio = ["trio (>=0.22.0,<1.0)"]
 
 [[package]]
 name = "httpx"
@@ -863,21 +848,25 @@ test = ["portend", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-c
 
 [[package]]
 name = "jaraco-functools"
-version = "4.0.2"
+version = "4.1.0"
 description = "Functools like those found in stdlib"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "jaraco.functools-4.0.2-py3-none-any.whl", hash = "sha256:c9d16a3ed4ccb5a889ad8e0b7a343401ee5b2a71cee6ed192d3f68bc351e94e3"},
-    {file = "jaraco_functools-4.0.2.tar.gz", hash = "sha256:3460c74cd0d32bf82b9576bbb3527c4364d5b27a21f5158a62aed6c4b42e23f5"},
+    {file = "jaraco.functools-4.1.0-py3-none-any.whl", hash = "sha256:ad159f13428bc4acbf5541ad6dec511f91573b90fba04df61dafa2a1231cf649"},
+    {file = "jaraco_functools-4.1.0.tar.gz", hash = "sha256:70f7e0e2ae076498e212562325e805204fc092d7b4c17e0e86c959e249701a9d"},
 ]
 
 [package.dependencies]
 more-itertools = "*"
 
 [package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
+cover = ["pytest-cov"]
 doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-test = ["jaraco.classes", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["jaraco.classes", "pytest (>=6,!=8.1.*)"]
+type = ["pytest-mypy"]
 
 [[package]]
 name = "jeepney"
@@ -1788,13 +1777,13 @@ windows-terminal = ["colorama (>=0.4.6)"]
 
 [[package]]
 name = "pyproject-hooks"
-version = "1.1.0"
+version = "1.2.0"
 description = "Wrappers to call pyproject.toml-based build backend hooks."
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "pyproject_hooks-1.1.0-py3-none-any.whl", hash = "sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2"},
-    {file = "pyproject_hooks-1.1.0.tar.gz", hash = "sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965"},
+    {file = "pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913"},
+    {file = "pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8"},
 ]
 
 [[package]]
@@ -1992,13 +1981,13 @@ docs = ["sphinx"]
 
 [[package]]
 name = "python-multipart"
-version = "0.0.10"
+version = "0.0.12"
 description = "A streaming multipart parser for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "python_multipart-0.0.10-py3-none-any.whl", hash = "sha256:2b06ad9e8d50c7a8db80e3b56dab590137b323410605af2be20d62a5f1ba1dc8"},
-    {file = "python_multipart-0.0.10.tar.gz", hash = "sha256:46eb3c6ce6fdda5fb1a03c7e11d490e407c6930a2703fe7aef4da71c374688fa"},
+    {file = "python_multipart-0.0.12-py3-none-any.whl", hash = "sha256:43dcf96cf65888a9cd3423544dd0d75ac10f7aa0c3c28a175bbcd00c9ce1aebf"},
+    {file = "python_multipart-0.0.12.tar.gz", hash = "sha256:045e1f98d719c1ce085ed7f7e1ef9d8ccc8c02ba02b5566d5f7521410ced58cb"},
 ]
 
 [[package]]
@@ -2143,31 +2132,31 @@ md = ["cmarkgfm (>=0.8.0)"]
 
 [[package]]
 name = "redis"
-version = "5.0.8"
+version = "5.1.0"
 description = "Python client for Redis database and key-value store"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "redis-5.0.8-py3-none-any.whl", hash = "sha256:56134ee08ea909106090934adc36f65c9bcbbaecea5b21ba704ba6fb561f8eb4"},
-    {file = "redis-5.0.8.tar.gz", hash = "sha256:0c5b10d387568dfe0698c6fad6615750c24170e548ca2deac10c649d463e9870"},
+    {file = "redis-5.1.0-py3-none-any.whl", hash = "sha256:fd4fccba0d7f6aa48c58a78d76ddb4afc698f5da4a2c1d03d916e4fd7ab88cdd"},
+    {file = "redis-5.1.0.tar.gz", hash = "sha256:b756df1e4a3858fcc0ef861f3fc53623a96c41e2b1f5304e09e0fe758d333d40"},
 ]
 
 [package.dependencies]
 async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
 
 [package.extras]
-hiredis = ["hiredis (>1.0.0)"]
-ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
+hiredis = ["hiredis (>=3.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
 
 [[package]]
 name = "reflex-chakra"
-version = "0.6.0"
+version = "0.6.1"
 description = "reflex using chakra components"
 optional = false
 python-versions = "<4.0,>=3.8"
 files = [
-    {file = "reflex_chakra-0.6.0-py3-none-any.whl", hash = "sha256:eca1593fca67289e05591dd21fbcc8632c119d64a08bdc41fd995055a114cc91"},
-    {file = "reflex_chakra-0.6.0.tar.gz", hash = "sha256:db1c7b48f1ba547bf91e5af103fce6fc7191d7225b414ebfbada7d983e33dd87"},
+    {file = "reflex_chakra-0.6.1-py3-none-any.whl", hash = "sha256:824d461264b6d2c836ba4a2a430e677a890b82e83da149672accfc58786442fa"},
+    {file = "reflex_chakra-0.6.1.tar.gz", hash = "sha256:4b9b3c8bada19cbb4d1b8d8bc4ab0460ec008a91f380010c34d416d5b613dc07"},
 ]
 
 [package.dependencies]
@@ -2247,18 +2236,19 @@ idna2008 = ["idna"]
 
 [[package]]
 name = "rich"
-version = "13.8.1"
+version = "13.9.1"
 description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
 optional = false
-python-versions = ">=3.7.0"
+python-versions = ">=3.8.0"
 files = [
-    {file = "rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06"},
-    {file = "rich-13.8.1.tar.gz", hash = "sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a"},
+    {file = "rich-13.9.1-py3-none-any.whl", hash = "sha256:b340e739f30aa58921dc477b8adaa9ecdb7cecc217be01d93730ee1bc8aa83be"},
+    {file = "rich-13.9.1.tar.gz", hash = "sha256:097cffdf85db1babe30cc7deba5ab3a29e1b9885047dab24c57e9a7f8a9c1466"},
 ]
 
 [package.dependencies]
 markdown-it-py = ">=2.2.0"
 pygments = ">=2.13.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""}
 
 [package.extras]
 jupyter = ["ipywidgets (>=7.5.1,<9)"]
@@ -2595,13 +2585,13 @@ files = [
 
 [[package]]
 name = "tomli"
-version = "2.0.1"
+version = "2.0.2"
 description = "A lil' TOML parser"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
-    {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+    {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"},
+    {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
 ]
 
 [[package]]
@@ -2734,13 +2724,13 @@ zstd = ["zstandard (>=0.18.0)"]
 
 [[package]]
 name = "uvicorn"
-version = "0.30.6"
+version = "0.31.0"
 description = "The lightning-fast ASGI server."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "uvicorn-0.30.6-py3-none-any.whl", hash = "sha256:65fd46fe3fda5bdc1b03b94eb634923ff18cd35b2f084813ea79d1f103f711b5"},
-    {file = "uvicorn-0.30.6.tar.gz", hash = "sha256:4b15decdda1e72be08209e860a1e10e92439ad5b97cf44cc945fcbee66fc5788"},
+    {file = "uvicorn-0.31.0-py3-none-any.whl", hash = "sha256:cac7be4dd4d891c363cd942160a7b02e69150dcbc7a36be04d5f4af4b17c8ced"},
+    {file = "uvicorn-0.31.0.tar.gz", hash = "sha256:13bc21373d103859f68fe739608e2eb054a816dea79189bc3ca08ea89a275906"},
 ]
 
 [package.dependencies]
@@ -2753,13 +2743,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)",
 
 [[package]]
 name = "virtualenv"
-version = "20.26.5"
+version = "20.26.6"
 description = "Virtual Python Environment builder"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "virtualenv-20.26.5-py3-none-any.whl", hash = "sha256:4f3ac17b81fba3ce3bd6f4ead2749a72da5929c01774948e243db9ba41df4ff6"},
-    {file = "virtualenv-20.26.5.tar.gz", hash = "sha256:ce489cac131aa58f4b25e321d6d186171f78e6cb13fafbf32a840cee67733ff4"},
+    {file = "virtualenv-20.26.6-py3-none-any.whl", hash = "sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2"},
+    {file = "virtualenv-20.26.6.tar.gz", hash = "sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48"},
 ]
 
 [package.dependencies]
@@ -3011,4 +3001,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "adccd071775567aeefe219261aeb9e222906c865745f03edb1e770edc79c44ac"
+content-hash = "e4b462ebfae90550ba7fa49b360d7110c0d344ee616c23989c22d866ef8f6f31"

+ 0 - 1
pyproject.toml

@@ -27,7 +27,6 @@ packages = [
 
 [tool.poetry.dependencies]
 python = "^3.9"
-dill = ">=0.3.8,<0.4"
 fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
 gunicorn = ">=20.1.0,<24.0"
 jinja2 = ">=3.1.2,<4.0"

+ 56 - 31
reflex/state.py

@@ -9,6 +9,7 @@ import dataclasses
 import functools
 import inspect
 import os
+import pickle
 import uuid
 from abc import ABC, abstractmethod
 from collections import defaultdict
@@ -19,6 +20,7 @@ from typing import (
     TYPE_CHECKING,
     Any,
     AsyncIterator,
+    BinaryIO,
     Callable,
     ClassVar,
     Dict,
@@ -33,7 +35,6 @@ from typing import (
     get_type_hints,
 )
 
-import dill
 from sqlalchemy.orm import DeclarativeBase
 from typing_extensions import Self
 
@@ -76,6 +77,7 @@ from reflex.utils.exceptions import (
     ImmutableStateError,
     LockExpiredError,
     SetUndefinedStateVarError,
+    StateSchemaMismatchError,
 )
 from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import serializer
@@ -1914,7 +1916,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     def __getstate__(self):
         """Get the state for redis serialization.
 
-        This method is called by cloudpickle to serialize the object.
+        This method is called by pickle to serialize the object.
 
         It explicitly removes parent_state and substates because those are serialized separately
         by the StateManagerRedis to allow for better horizontal scaling as state size increases.
@@ -1930,6 +1932,43 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         state["__dict__"].pop("_was_touched", None)
         return state
 
+    def _serialize(self) -> bytes:
+        """Serialize the state for redis.
+
+        Returns:
+            The serialized state.
+        """
+        return pickle.dumps((state_to_schema(self), self))
+
+    @classmethod
+    def _deserialize(
+        cls, data: bytes | None = None, fp: BinaryIO | None = None
+    ) -> BaseState:
+        """Deserialize the state from redis/disk.
+
+        data and fp are mutually exclusive, but one must be provided.
+
+        Args:
+            data: The serialized state data.
+            fp: The file pointer to the serialized state data.
+
+        Returns:
+            The deserialized state.
+
+        Raises:
+            ValueError: If both data and fp are provided, or neither are provided.
+            StateSchemaMismatchError: If the state schema does not match the expected schema.
+        """
+        if data is not None and fp is None:
+            (substate_schema, state) = pickle.loads(data)
+        elif fp is not None and data is None:
+            (substate_schema, state) = pickle.load(fp)
+        else:
+            raise ValueError("Only one of `data` or `fp` must be provided")
+        if substate_schema != state_to_schema(state):
+            raise StateSchemaMismatchError()
+        return state
+
 
 class State(BaseState):
     """The app Base State."""
@@ -2086,7 +2125,11 @@ class ComponentState(State, mixin=True):
         """
         cls._per_component_state_instance_count += 1
         state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
-        component_state = type(state_cls_name, (cls, State), {}, mixin=False)
+        component_state = type(
+            state_cls_name, (cls, State), {"__module__": __name__}, mixin=False
+        )
+        # Save a reference to the dynamic state for pickle/unpickle.
+        globals()[state_cls_name] = component_state
         component = component_state.get_component(*children, **props)
         component.State = component_state
         return component
@@ -2552,7 +2595,7 @@ def is_serializable(value: Any) -> bool:
         Whether the value is serializable.
     """
     try:
-        return bool(dill.dumps(value))
+        return bool(pickle.dumps(value))
     except Exception:
         return False
 
@@ -2688,8 +2731,7 @@ class StateManagerDisk(StateManager):
         if token_path.exists():
             try:
                 with token_path.open(mode="rb") as file:
-                    (substate_schema, substate) = dill.load(file)
-                if substate_schema == state_to_schema(substate):
+                    substate = BaseState._deserialize(fp=file)
                     await self.populate_substates(client_token, substate, root_state)
                     return substate
             except Exception:
@@ -2731,10 +2773,12 @@ class StateManagerDisk(StateManager):
         client_token, substate_address = _split_substate_key(token)
 
         root_state_token = _substate_key(client_token, substate_address.split(".")[0])
+        root_state = self.states.get(root_state_token)
+        if root_state is None:
+            # Create a new root state which will be persisted in the next set_state call.
+            root_state = self.state(_reflex_internal_init=True)
 
-        return await self.load_state(
-            root_state_token, self.state(_reflex_internal_init=True)
-        )
+        return await self.load_state(root_state_token, root_state)
 
     async def set_state_for_substate(self, client_token: str, substate: BaseState):
         """Set the state for a substate.
@@ -2747,7 +2791,7 @@ class StateManagerDisk(StateManager):
 
         self.states[substate_token] = substate
 
-        state_dilled = dill.dumps((state_to_schema(substate), substate))
+        state_dilled = substate._serialize()
         if not self.states_directory.exists():
             self.states_directory.mkdir(parents=True, exist_ok=True)
         self.token_path(substate_token).write_bytes(state_dilled)
@@ -2790,25 +2834,6 @@ class StateManagerDisk(StateManager):
             await self.set_state(token, state)
 
 
-# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
-if not isinstance(State.validate.__func__, FunctionType):
-    cython_function_or_method = type(State.validate.__func__)
-
-    @dill.register(cython_function_or_method)
-    def _dill_reduce_cython_function_or_method(pickler, obj):
-        # Ignore cython function when pickling.
-        pass
-
-
-@dill.register(type(State))
-def _dill_reduce_state(pickler, obj):
-    if obj is not State and issubclass(obj, State):
-        # Avoid serializing subclasses of State, instead get them by reference from the State class.
-        pickler.save_reduce(State.get_class_substate, (obj.get_full_name(),), obj=obj)
-    else:
-        dill.Pickler.dispatch[type](pickler, obj)
-
-
 def _default_lock_expiration() -> int:
     """Get the default lock expiration time.
 
@@ -2948,7 +2973,7 @@ class StateManagerRedis(StateManager):
 
         if redis_state is not None:
             # Deserialize the substate.
-            state = dill.loads(redis_state)
+            state = BaseState._deserialize(data=redis_state)
 
             # Populate parent state if missing and requested.
             if parent_state is None:
@@ -3060,7 +3085,7 @@ class StateManagerRedis(StateManager):
             )
         # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
         if state._get_was_touched():
-            pickle_state = dill.dumps(state, byref=True)
+            pickle_state = state._serialize()
             self._warn_if_too_large(state, len(pickle_state))
             await self.redis.set(
                 _substate_key(client_token, state),

+ 4 - 0
reflex/utils/exceptions.py

@@ -123,3 +123,7 @@ class DynamicComponentMissingLibrary(ReflexError, ValueError):
 
 class SetUndefinedStateVarError(ReflexError, AttributeError):
     """Raised when setting the value of a var without first declaring it."""
+
+
+class StateSchemaMismatchError(ReflexError, TypeError):
+    """Raised when the serialized schema of a state class does not match the current schema."""