Переглянути джерело

[ENG-4713] Cache pages which add states when evaluating (#4788)

* cache order of imports that create BaseState subclasses

* Track which pages create State subclasses during evaluation

These need to be replayed on the backend to ensure state alignment.

* Clean up: use constants, remove unused code

Handle closing files with contextmanager

* Expose app.add_all_routes_endpoint for flexgen

* Include .web/backend directory in backend.zip when exporting
Masen Furer 2 місяців тому
батько
коміт
deb1f4f702

+ 52 - 2
reflex/app.py

@@ -100,6 +100,7 @@ from reflex.state import (
     StateManager,
     StateUpdate,
     _substate_key,
+    all_base_state_classes,
     code_uses_state_contexts,
 )
 from reflex.utils import (
@@ -117,6 +118,7 @@ from reflex.utils.imports import ImportVar
 if TYPE_CHECKING:
     from reflex.vars import Var
 
+
 # Define custom types.
 ComponentCallable = Callable[[], Component]
 Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
@@ -375,6 +377,9 @@ class App(MiddlewareMixin, LifespanMixin):
     # A map from a page route to the component to render. Users should use `add_page`.
     _pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
 
+    # A mapping of pages which created states as they were being evaluated.
+    _stateful_pages: Dict[str, None] = dataclasses.field(default_factory=dict)
+
     # The backend API object.
     _api: FastAPI | None = None
 
@@ -592,8 +597,10 @@ class App(MiddlewareMixin, LifespanMixin):
         """Add optional api endpoints (_upload)."""
         if not self.api:
             return
-
-        if Upload.is_used:
+        upload_is_used_marker = (
+            prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED
+        )
+        if Upload.is_used or upload_is_used_marker.exists():
             # To upload files.
             self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
 
@@ -603,10 +610,15 @@ class App(MiddlewareMixin, LifespanMixin):
                 StaticFiles(directory=get_upload_dir()),
                 name="uploaded_files",
             )
+
+            upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
+            upload_is_used_marker.touch()
         if codespaces.is_running_in_codespaces():
             self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
                 codespaces.auth_codespace
             )
+        if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
+            self.add_all_routes_endpoint()
 
     def _add_cors(self):
         """Add CORS middleware to the app."""
@@ -747,13 +759,19 @@ class App(MiddlewareMixin, LifespanMixin):
             route: The route of the page to compile.
             save_page: If True, the compiled page is saved to self._pages.
         """
+        n_states_before = len(all_base_state_classes)
         component, enable_state = compiler.compile_unevaluated_page(
             route, self._unevaluated_pages[route], self._state, self.style, self.theme
         )
 
+        # Indicate that the app should use state.
         if enable_state:
             self._enable_state()
 
+        # Indicate that evaluating this page creates one or more state classes.
+        if len(all_base_state_classes) > n_states_before:
+            self._stateful_pages[route] = None
+
         # Add the page.
         self._check_routes_conflict(route)
         if save_page:
@@ -1042,6 +1060,20 @@ class App(MiddlewareMixin, LifespanMixin):
         def get_compilation_time() -> str:
             return str(datetime.now().time()).split(".")[0]
 
+        should_compile = self._should_compile()
+        backend_dir = prerequisites.get_backend_dir()
+        if not should_compile and backend_dir.exists():
+            stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES
+            if stateful_pages_marker.exists():
+                with stateful_pages_marker.open("r") as f:
+                    stateful_pages = json.load(f)
+                for route in stateful_pages:
+                    console.info(f"BE Evaluating stateful page: {route}")
+                    self._compile_page(route, save_page=False)
+                self._enable_state()
+            self._add_optional_endpoints()
+            return
+
         # Render a default 404 page if the user didn't supply one
         if constants.Page404.SLUG not in self._unevaluated_pages:
             self.add_page(route=constants.Page404.SLUG)
@@ -1343,6 +1375,24 @@ class App(MiddlewareMixin, LifespanMixin):
             for output_path, code in compile_results:
                 compiler_utils.write_page(output_path, code)
 
+        # Write list of routes that create dynamic states for backend to use.
+        if self._state is not None:
+            stateful_pages_marker = (
+                prerequisites.get_backend_dir() / constants.Dirs.STATEFUL_PAGES
+            )
+            stateful_pages_marker.parent.mkdir(parents=True, exist_ok=True)
+            with stateful_pages_marker.open("w") as f:
+                json.dump(list(self._stateful_pages), f)
+
+    def add_all_routes_endpoint(self):
+        """Add an endpoint to the app that returns all the routes."""
+        if not self.api:
+            return
+
+        @self.api.get(str(constants.Endpoint.ALL_ROUTES))
+        async def all_routes():
+            return list(self._unevaluated_pages.keys())
+
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state out of band.

+ 3 - 0
reflex/config.py

@@ -713,6 +713,9 @@ class EnvironmentVariables:
     # Paths to exclude from the hot reload. Takes precedence over include paths. Separated by a colon.
     REFLEX_HOT_RELOAD_EXCLUDE_PATHS: EnvVar[List[Path]] = env_var([])
 
+    # Used by flexgen to enumerate the pages.
+    REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False)
+
 
 environment = EnvironmentVariables()
 

+ 6 - 0
reflex/constants/base.py

@@ -53,6 +53,12 @@ class Dirs(SimpleNamespace):
     POSTCSS_JS = "postcss.config.js"
     # The name of the states directory.
     STATES = ".states"
+    # Where compilation artifacts for the backend are stored.
+    BACKEND = "backend"
+    # JSON-encoded list of page routes that need to be evaluated on the backend.
+    STATEFUL_PAGES = "stateful_pages.json"
+    # Marker file indicating that upload component was used in the frontend.
+    UPLOAD_IS_USED = "upload_is_used"
 
 
 class Reflex(SimpleNamespace):

+ 1 - 0
reflex/constants/event.py

@@ -12,6 +12,7 @@ class Endpoint(Enum):
     UPLOAD = "_upload"
     AUTH_CODESPACE = "auth-codespace"
     HEALTH = "_health"
+    ALL_ROUTES = "_all_routes"
 
     def __str__(self) -> str:
         """Get the string representation of the endpoint.

+ 6 - 0
reflex/state.py

@@ -327,6 +327,9 @@ async def _resolve_delta(delta: Delta) -> Delta:
     return delta
 
 
+all_base_state_classes: dict[str, None] = {}
+
+
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
 
@@ -624,6 +627,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         cls._var_dependencies = {}
         cls._init_var_dependency_dicts()
 
+        all_base_state_classes[cls.get_full_name()] = None
+
     @staticmethod
     def _copy_fn(fn: Callable) -> Callable:
         """Copy a function. Used to copy ComputedVars and EventHandlers from mixins.
@@ -4087,6 +4092,7 @@ def reload_state_module(
     for subclass in tuple(state.class_subclasses):
         reload_state_module(module=module, state=subclass)
         if subclass.__module__ == module and module is not None:
+            all_base_state_classes.pop(subclass.get_full_name(), None)
             state.class_subclasses.remove(subclass)
             state._always_dirty_substates.discard(subclass.get_name())
             state._var_dependencies = {}

+ 12 - 0
reflex/utils/build.py

@@ -60,6 +60,7 @@ def _zip(
     dirs_to_exclude: set[str] | None = None,
     files_to_exclude: set[str] | None = None,
     top_level_dirs_to_exclude: set[str] | None = None,
+    globs_to_include: list[str] | None = None,
 ) -> None:
     """Zip utility function.
 
@@ -72,6 +73,7 @@ def _zip(
         dirs_to_exclude: The directories to exclude.
         files_to_exclude: The files to exclude.
         top_level_dirs_to_exclude: The top level directory names immediately under root_dir to exclude. Do not exclude folders by these names further in the sub-directories.
+        globs_to_include: Apply these globs from the root_dir and always include them in the zip.
 
     """
     target = Path(target)
@@ -103,6 +105,13 @@ def _zip(
         files_to_zip += [
             str(root / file) for file in files if file not in files_to_exclude
         ]
+    if globs_to_include:
+        for glob in globs_to_include:
+            files_to_zip += [
+                str(file)
+                for file in root_dir.glob(glob)
+                if file.name not in files_to_exclude
+            ]
 
     # Create a progress bar for zipping the component.
     progress = Progress(
@@ -160,6 +169,9 @@ def zip_app(
             top_level_dirs_to_exclude={"assets"},
             exclude_venv_dirs=True,
             upload_db_file=upload_db_file,
+            globs_to_include=[
+                str(Path(constants.Dirs.WEB) / constants.Dirs.BACKEND / "*")
+            ],
         )
 
 

+ 9 - 0
reflex/utils/prerequisites.py

@@ -99,6 +99,15 @@ def get_states_dir() -> Path:
     return environment.REFLEX_STATES_WORKDIR.get()
 
 
+def get_backend_dir() -> Path:
+    """Get the working directory for the backend.
+
+    Returns:
+        The working directory.
+    """
+    return get_web_dir() / constants.Dirs.BACKEND
+
+
 def check_latest_package_version(package_name: str):
     """Check if the latest version of the package is installed.