Ver código fonte

WiP - pickle dynamic states to bring backend up faster

Masen Furer 3 meses atrás
pai
commit
52d98b125a
2 arquivos alterados com 88 adições e 2 exclusões
  1. 83 2
      reflex/app.py
  2. 5 0
      reflex/vars/base.py

+ 83 - 2
reflex/app.py

@@ -13,11 +13,12 @@ import io
 import json
 import json
 import multiprocessing
 import multiprocessing
 import platform
 import platform
+import shutil
 import sys
 import sys
 import traceback
 import traceback
 from datetime import datetime
 from datetime import datetime
 from pathlib import Path
 from pathlib import Path
-from types import SimpleNamespace
+from types import FunctionType, SimpleNamespace
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
     Any,
     Any,
@@ -39,11 +40,13 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
 from fastapi.middleware import cors
 from fastapi.middleware import cors
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.staticfiles import StaticFiles
+from rich.console import ConsoleThreadLocals
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from starlette_admin.contrib.sqla.admin import Admin
 from starlette_admin.contrib.sqla.admin import Admin
 from starlette_admin.contrib.sqla.view import ModelView
 from starlette_admin.contrib.sqla.view import ModelView
 
 
+import reflex.istate.dynamic
 from reflex import constants
 from reflex import constants
 from reflex.admin import AdminDash
 from reflex.admin import AdminDash
 from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
 from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
@@ -97,10 +100,34 @@ from reflex.state import (
 from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
 from reflex.utils import codespaces, console, exceptions, format, prerequisites, types
 from reflex.utils.exec import is_prod_mode, is_testing_env
 from reflex.utils.exec import is_prod_mode, is_testing_env
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
+from reflex.vars.base import ComputedVar
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.vars import Var
     from reflex.vars import Var
 
 
+try:
+    import dill
+except ImportError:
+    dill = None
+else:
+    # Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
+    if not isinstance(State.validate.__func__, FunctionType):
+        import builtins
+
+        cython_function_or_method = type(State.validate.__func__)
+        builtins.cython_function_or_method = cython_function_or_method
+
+        @dill.register(cython_function_or_method)
+        def _dill_reduce_cython_function_or_method(pickler, obj):
+            # Ignore cython function when pickling.
+            pass
+
+    @dill.register(ConsoleThreadLocals)
+    def _dill_reduce_console_thread_locals(pickler, obj):
+        # Ignore console thread locals when pickling.
+        pass
+
+
 # Define custom types.
 # Define custom types.
 ComponentCallable = Callable[[], Component]
 ComponentCallable = Callable[[], Component]
 Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
 Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
@@ -337,6 +364,11 @@ class App(MiddlewareMixin, LifespanMixin):
         if not self.state:
         if not self.state:
             self.state = State
             self.state = State
             self._setup_state()
             self._setup_state()
+            enable_state_marker = (
+                prerequisites.get_web_dir() / "backend" / "enable_state"
+            )
+            enable_state_marker.parent.mkdir(parents=True, exist_ok=True)
+            enable_state_marker.touch()
 
 
     def _setup_state(self) -> None:
     def _setup_state(self) -> None:
         """Set up the state for the app.
         """Set up the state for the app.
@@ -415,7 +447,10 @@ class App(MiddlewareMixin, LifespanMixin):
 
 
     def _add_optional_endpoints(self):
     def _add_optional_endpoints(self):
         """Add optional api endpoints (_upload)."""
         """Add optional api endpoints (_upload)."""
-        if Upload.is_used:
+        upload_is_used_marker = (
+            prerequisites.get_web_dir() / "backend" / "upload_is_used"
+        )
+        if Upload.is_used or upload_is_used_marker.exists():
             # To upload files.
             # To upload files.
             self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
             self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
 
 
@@ -425,6 +460,9 @@ class App(MiddlewareMixin, LifespanMixin):
                 StaticFiles(directory=get_upload_dir()),
                 StaticFiles(directory=get_upload_dir()),
                 name="uploaded_files",
                 name="uploaded_files",
             )
             )
+
+            upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
+            upload_is_used_marker.touch()
         if codespaces.is_running_in_codespaces():
         if codespaces.is_running_in_codespaces():
             self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
             self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
                 codespaces.auth_codespace
                 codespaces.auth_codespace
@@ -856,6 +894,18 @@ class App(MiddlewareMixin, LifespanMixin):
         def get_compilation_time() -> str:
         def get_compilation_time() -> str:
             return str(datetime.now().time()).split(".")[0]
             return str(datetime.now().time()).split(".")[0]
 
 
+        should_compile = self._should_compile()
+        backend_dir = prerequisites.get_web_dir() / "backend"
+        if not should_compile and backend_dir.exists():
+            enable_state_marker = backend_dir / "enable_state"
+            if enable_state_marker.exists():
+                self._enable_state()
+            pickle_states_root = backend_dir / "states"
+            if pickle_states_root.exists():
+                self._unpickle_dynamic_states(pickle_states_root)
+            self._add_optional_endpoints()
+            return
+
         # Render a default 404 page if the user didn't supply one
         # Render a default 404 page if the user didn't supply one
         if constants.Page404.SLUG not in self.unevaluated_pages:
         if constants.Page404.SLUG not in self.unevaluated_pages:
             self.add_page(route=constants.Page404.SLUG)
             self.add_page(route=constants.Page404.SLUG)
@@ -1077,6 +1127,37 @@ class App(MiddlewareMixin, LifespanMixin):
         for output_path, code in compile_results:
         for output_path, code in compile_results:
             compiler_utils.write_page(output_path, code)
             compiler_utils.write_page(output_path, code)
 
 
+        # Pickle dynamic states
+        if self.state is not None and dill is not None:
+            pickle_dir = prerequisites.get_web_dir() / "backend" / "states"
+            if pickle_dir.exists():
+                shutil.rmtree(pickle_dir)
+            pickle_dir.mkdir(parents=True, exist_ok=True)
+            unfuck_states = []
+            for state in reflex.istate.dynamic.__dict__.values():
+                if isinstance(state, type) and issubclass(state, self.state):
+                    unfuck_states.append(state)
+                    object.__setattr__(state.setvar, "state_cls", None)
+            ComputedVar._is_pickling = True
+            try:
+                dill.session.dump_session(
+                    filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic
+                )
+            except TypeError:
+                with dill.detect.trace():
+                    dill.session.dump_session(
+                        filename=pickle_dir / "dynamic.pkl", main=reflex.istate.dynamic
+                    )
+            ComputedVar._is_pickling = False
+            for state in unfuck_states:
+                object.__setattr__(state.setvar, "state_cls", state)
+
+    def _unpickle_dynamic_states(self, root: Path):
+        if dill is None:
+            raise ImportError("dill is required to unpickle dynamic states")
+        for pk_file in sorted(root.iterdir()):
+            dill.session.load_session(filename=pk_file, main=reflex.istate.dynamic)
+
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state out of band.
         """Modify the state out of band.

+ 5 - 0
reflex/vars/base.py

@@ -1834,6 +1834,9 @@ class ComputedVar(Var[RETURN_TYPE]):
         default_factory=lambda: lambda _: None
         default_factory=lambda: lambda _: None
     )  # type: ignore
     )  # type: ignore
 
 
+    # Flag determines whether we are pickling the computed var itself
+    _is_pickling: ClassVar[bool] = False
+
     def __init__(
     def __init__(
         self,
         self,
         fget: Callable[[BASE_STATE], RETURN_TYPE],
         fget: Callable[[BASE_STATE], RETURN_TYPE],
@@ -2227,6 +2230,8 @@ class ComputedVar(Var[RETURN_TYPE]):
         Returns:
         Returns:
             The class of the var.
             The class of the var.
         """
         """
+        if self._is_pickling:
+            return type(self)
         return FakeComputedVarBaseClass
         return FakeComputedVarBaseClass
 
 
     @property
     @property