Sfoglia il codice sorgente

WiP - pickle dynamic states to bring backend up faster

Masen Furer 3 mesi fa
parent
commit
52d98b125a
2 ha cambiato i file con 88 aggiunte e 2 eliminazioni
  1. 83 2
      reflex/app.py
  2. 5 0
      reflex/vars/base.py

+ 83 - 2
reflex/app.py

@@ -13,11 +13,12 @@ import io
 import json
 import multiprocessing
 import platform
+import shutil
 import sys
 import traceback
 from datetime import datetime
 from pathlib import Path
-from types import SimpleNamespace
+from types import FunctionType, SimpleNamespace
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -39,11 +40,13 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
 from fastapi.middleware import cors
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
+from rich.console import ConsoleThreadLocals
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from starlette_admin.contrib.sqla.admin import Admin
 from starlette_admin.contrib.sqla.view import ModelView
 
+import reflex.istate.dynamic
 from reflex import constants
 from reflex.admin import AdminDash
 from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
@@ -97,10 +100,34 @@ from reflex.state import (
 from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
 from reflex.utils.exec import is_prod_mode, is_testing_env
 from reflex.utils.imports import ImportVar
+from reflex.vars.base import ComputedVar
 
 if TYPE_CHECKING:
     from reflex.vars import Var
 
+try:
+    import dill
+except ImportError:
+    dill = None
+else:
+    # Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
+    if not isinstance(State.validate.__func__, FunctionType):
+        import builtins
+
+        cython_function_or_method = type(State.validate.__func__)
+        builtins.cython_function_or_method = cython_function_or_method
+
+        @dill.register(cython_function_or_method)
+        def _dill_reduce_cython_function_or_method(pickler, obj):
+            # Ignore cython function when pickling.
+            pass
+
+    @dill.register(ConsoleThreadLocals)
+    def _dill_reduce_console_thread_locals(pickler, obj):
+        # Ignore console thread locals when pickling.
+        pass
+
+
 # Define custom types.
 ComponentCallable = Callable[[], Component]
 Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
@@ -337,6 +364,11 @@ class App(MiddlewareMixin, LifespanMixin):
         if not self.state:
             self.state = State
             self._setup_state()
+            enable_state_marker = (
+                prerequisites.get_web_dir() / "backend" / "enable_state"
+            )
+            enable_state_marker.parent.mkdir(parents=True, exist_ok=True)
+            enable_state_marker.touch()
 
     def _setup_state(self) -> None:
         """Set up the state for the app.
@@ -415,7 +447,10 @@ class App(MiddlewareMixin, LifespanMixin):
 
     def _add_optional_endpoints(self):
         """Add optional api endpoints (_upload)."""
-        if Upload.is_used:
+        upload_is_used_marker = (
+            prerequisites.get_web_dir() / "backend" / "upload_is_used"
+        )
+        if Upload.is_used or upload_is_used_marker.exists():
             # To upload files.
             self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
 
@@ -425,6 +460,9 @@ class App(MiddlewareMixin, LifespanMixin):
                 StaticFiles(directory=get_upload_dir()),
                 name="uploaded_files",
             )
+
+            upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
+            upload_is_used_marker.touch()
         if codespaces.is_running_in_codespaces():
             self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
                 codespaces.auth_codespace
@@ -856,6 +894,18 @@ class App(MiddlewareMixin, LifespanMixin):
         def get_compilation_time() -> str:
             return str(datetime.now().time()).split(".")[0]
 
+        should_compile = self._should_compile()
+        backend_dir = prerequisites.get_web_dir() / "backend"
+        if not should_compile and backend_dir.exists():
+            enable_state_marker = backend_dir / "enable_state"
+            if enable_state_marker.exists():
+                self._enable_state()
+            pickle_states_root = backend_dir / "states"
+            if pickle_states_root.exists():
+                self._unpickle_dynamic_states(pickle_states_root)
+            self._add_optional_endpoints()
+            return
+
         # Render a default 404 page if the user didn't supply one
         if constants.Page404.SLUG not in self.unevaluated_pages:
             self.add_page(route=constants.Page404.SLUG)
@@ -1077,6 +1127,37 @@ class App(MiddlewareMixin, LifespanMixin):
         for output_path, code in compile_results:
             compiler_utils.write_page(output_path, code)
 
+        # Pickle dynamic states
+        if self.state is not None and dill is not None:
+            pickle_dir = prerequisites.get_web_dir() / "backend" / "states"
+            if pickle_dir.exists():
+                shutil.rmtree(pickle_dir)
+            pickle_dir.mkdir(parents=True, exist_ok=True)
+            unfuck_states = []
+            for state in reflex.istate.dynamic.__dict__.values():
+                if isinstance(state, type) and issubclass(state, self.state):
+                    unfuck_states.append(state)
+                    object.__setattr__(state.setvar, "state_cls", None)
+            ComputedVar._is_pickling = True
+            try:
+                dill.session.dump_session(
+                    filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic
+                )
+            except TypeError:
+                with dill.detect.trace():
+                    dill.session.dump_session(
+                        filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic
+                    )
+            ComputedVar._is_pickling = False
+            for state in unfuck_states:
+                object.__setattr__(state.setvar, "state_cls", state)
+
+    def _unpickle_dynamic_states(self, root: Path):
+        if dill is None:
+            raise ImportError("dill is required to unpickle dynamic states")
+        for pk_file in sorted(root.iterdir()):
+            dill.session.load_session(filename=pk_file, main=reflex.istate.dynamic)
+
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state out of band.

+ 5 - 0
reflex/vars/base.py

@@ -1834,6 +1834,9 @@ class ComputedVar(Var[RETURN_TYPE]):
         default_factory=lambda: lambda _: None
     )  # type: ignore
 
+    # Flag determines whether we are pickling the computed var itself
+    _is_pickling: ClassVar[bool] = False
+
     def __init__(
         self,
         fget: Callable[[BASE_STATE], RETURN_TYPE],
@@ -2227,6 +2230,8 @@ class ComputedVar(Var[RETURN_TYPE]):
         Returns:
             The class of the var.
         """
+        if self._is_pickling:
+            return type(self)
         return FakeComputedVarBaseClass
 
     @property