浏览代码

[REF-2122] Opt-in multiprocess compile (#2838)

* Revert "Revert "Revert "Revert "use process pool to compile faster (#2377)" (#2434)" (#2497)" (#2595)"

This reverts commit 6b6eea4d7d3b3476738f26460524774adce3ca2b.

* Adjust number of operations for more correct progress bar

* app: recognize REFLEX_COMPILE_PROCESSES and REFLEX_COMPILE_THREADS

Control whether multiprocessing is used and the number of processes or threads
that should be used.

This will allow users to opt-in to the new, potentially hazardous,
multiprocessing mode, which results in much faster compiles, but has already
been reverted 4 times. Lets leave the code in this time, but use the thread
pool executor by default.

Limiting the number of threads or processes to 1 can also aid in debugging
issues that arise during compile time.

* Allow REFLEX_COMPILE_PROCESSES=0 to trigger multiprocessing with auto workers
Masen Furer 1 年之前
父节点
当前提交
ee1ff7f93f
共有 2 个文件被更改,包括 212 次插入93 次删除
  1. 125 91
      reflex/app.py
  2. 87 2
      reflex/compiler/compiler.py

+ 125 - 91
reflex/app.py

@@ -7,7 +7,9 @@ import concurrent.futures
 import contextlib
 import contextlib
 import copy
 import copy
 import functools
 import functools
+import multiprocessing
 import os
 import os
+import platform
 from typing import (
 from typing import (
     Any,
     Any,
     AsyncIterator,
     AsyncIterator,
@@ -37,6 +39,7 @@ from reflex.admin import AdminDash
 from reflex.base import Base
 from reflex.base import Base
 from reflex.compiler import compiler
 from reflex.compiler import compiler
 from reflex.compiler import utils as compiler_utils
 from reflex.compiler import utils as compiler_utils
+from reflex.compiler.compiler import ExecutorSafeFunctions
 from reflex.components import connection_modal, connection_pulser
 from reflex.components import connection_modal, connection_pulser
 from reflex.components.base.app_wrap import AppWrap
 from reflex.components.base.app_wrap import AppWrap
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.fragment import Fragment
@@ -754,6 +757,17 @@ class App(Base):
             TimeElapsedColumn(),
             TimeElapsedColumn(),
         )
         )
 
 
+        # try to be somewhat accurate - but still not 100%
+        adhoc_steps_without_executor = 6
+        fixed_pages_within_executor = 5
+        progress.start()
+        task = progress.add_task(
+            "Compiling:",
+            total=len(self.pages)
+            + fixed_pages_within_executor
+            + adhoc_steps_without_executor,
+        )
+
         # Get the env mode.
         # Get the env mode.
         config = get_config()
         config = get_config()
 
 
@@ -769,6 +783,8 @@ class App(Base):
             # If a theme component was provided, wrap the app with it
             # If a theme component was provided, wrap the app with it
             app_wrappers[(20, "Theme")] = self.theme
             app_wrappers[(20, "Theme")] = self.theme
 
 
+        progress.advance(task)
+
         # Fix up the style.
         # Fix up the style.
         self.style = evaluate_style_namespaces(self.style)
         self.style = evaluate_style_namespaces(self.style)
 
 
@@ -776,138 +792,156 @@ class App(Base):
         all_imports = {}
         all_imports = {}
         custom_components = set()
         custom_components = set()
 
 
-        # Compile the pages in parallel.
-        with progress, concurrent.futures.ThreadPoolExecutor() as thread_pool:
-            fixed_pages = 7
-            task = progress.add_task("Compiling:", total=len(self.pages) + fixed_pages)
+        for _route, component in self.pages.items():
+            # Merge the component style with the app style.
+            component.add_style(self.style)
 
 
-            def mark_complete(_=None):
-                progress.advance(task)
+            component.apply_theme(self.theme)
 
 
-            for _route, component in self.pages.items():
-                # Merge the component style with the app style.
-                component.add_style(self.style)
+            # Add component.get_imports() to all_imports.
+            all_imports.update(component.get_imports())
 
 
-                component.apply_theme(self.theme)
+            # Add the app wrappers from this component.
+            app_wrappers.update(component.get_app_wrap_components())
 
 
-                # Add component.get_imports() to all_imports.
-                all_imports.update(component.get_imports())
+            # Add the custom components from the page to the set.
+            custom_components |= component.get_custom_components()
 
 
-                # Add the app wrappers from this component.
-                app_wrappers.update(component.get_app_wrap_components())
+        progress.advance(task)
 
 
-                # Add the custom components from the page to the set.
-                custom_components |= component.get_custom_components()
+        # Perform auto-memoization of stateful components.
+        (
+            stateful_components_path,
+            stateful_components_code,
+            page_components,
+        ) = compiler.compile_stateful_components(self.pages.values())
 
 
-            # Perform auto-memoization of stateful components.
-            (
-                stateful_components_path,
-                stateful_components_code,
-                page_components,
-            ) = compiler.compile_stateful_components(self.pages.values())
+        progress.advance(task)
 
 
-            # Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State.
-            if (
-                code_uses_state_contexts(stateful_components_code)
-                and self.state is None
-            ):
-                raise RuntimeError(
-                    "To access rx.State in frontend components, at least one "
-                    "subclass of rx.State must be defined in the app."
-                )
-            compile_results.append((stateful_components_path, stateful_components_code))
+        # Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State.
+        if code_uses_state_contexts(stateful_components_code) and self.state is None:
+            raise RuntimeError(
+                "To access rx.State in frontend components, at least one "
+                "subclass of rx.State must be defined in the app."
+            )
+        compile_results.append((stateful_components_path, stateful_components_code))
+
+        # Compile the root document before fork.
+        compile_results.append(
+            compiler.compile_document_root(
+                self.head_components,
+                html_lang=self.html_lang,
+                html_custom_attrs=self.html_custom_attrs,  # type: ignore
+            )
+        )
+
+        # Compile the contexts before fork.
+        compile_results.append(
+            compiler.compile_contexts(self.state, self.theme),
+        )
+
+        app_root = self._app_root(app_wrappers=app_wrappers)
+
+        progress.advance(task)
 
 
+        # Prepopulate the global ExecutorSafeFunctions class with input data required by the compile functions.
+        # This is required for multiprocessing to work, in presence of non-picklable inputs.
+        for route, component in zip(self.pages, page_components):
+            ExecutorSafeFunctions.COMPILE_PAGE_ARGS_BY_ROUTE[route] = (
+                route,
+                component,
+                self.state,
+            )
+
+        ExecutorSafeFunctions.COMPILE_APP_APP_ROOT = app_root
+        ExecutorSafeFunctions.CUSTOM_COMPONENTS = custom_components
+        ExecutorSafeFunctions.STYLE = self.style
+
+        # Use a forking process pool, if possible.  Much faster, especially for large sites.
+        # Fallback to ThreadPoolExecutor as something that will always work.
+        executor = None
+        if (
+            platform.system() in ("Linux", "Darwin")
+            and os.environ.get("REFLEX_COMPILE_PROCESSES") is not None
+        ):
+            executor = concurrent.futures.ProcessPoolExecutor(
+                max_workers=int(os.environ.get("REFLEX_COMPILE_PROCESSES", 0)) or None,
+                mp_context=multiprocessing.get_context("fork"),
+            )
+        else:
+            executor = concurrent.futures.ThreadPoolExecutor(
+                max_workers=int(os.environ.get("REFLEX_COMPILE_THREADS", 0)) or None,
+            )
+
+        with executor:
             result_futures = []
             result_futures = []
             custom_components_future = None
             custom_components_future = None
 
 
-            def submit_work(fn, *args, **kwargs):
-                """Submit work to the thread pool and add a callback to mark the task as complete.
-
-                The Future will be added to the `result_futures` list.
+            def _mark_complete(_=None):
+                progress.advance(task)
 
 
-                Args:
-                    fn: The function to submit.
-                    *args: The args to submit.
-                    **kwargs: The kwargs to submit.
-                """
-                f = thread_pool.submit(fn, *args, **kwargs)
-                f.add_done_callback(mark_complete)
+            def _submit_work(fn, *args, **kwargs):
+                f = executor.submit(fn, *args, **kwargs)
+                f.add_done_callback(_mark_complete)
                 result_futures.append(f)
                 result_futures.append(f)
 
 
             # Compile all page components.
             # Compile all page components.
-            for route, component in zip(self.pages, page_components):
-                submit_work(
-                    compiler.compile_page,
-                    route,
-                    component,
-                    self.state,
-                )
+            for route in self.pages:
+                _submit_work(ExecutorSafeFunctions.compile_page, route)
 
 
             # Compile the app wrapper.
             # Compile the app wrapper.
-            app_root = self._app_root(app_wrappers=app_wrappers)
-            submit_work(compiler.compile_app, app_root)
+            _submit_work(ExecutorSafeFunctions.compile_app)
 
 
             # Compile the custom components.
             # Compile the custom components.
-            custom_components_future = thread_pool.submit(
-                compiler.compile_components, custom_components
+            custom_components_future = executor.submit(
+                ExecutorSafeFunctions.compile_custom_components,
             )
             )
-            custom_components_future.add_done_callback(mark_complete)
+            custom_components_future.add_done_callback(_mark_complete)
 
 
             # Compile the root stylesheet with base styles.
             # Compile the root stylesheet with base styles.
-            submit_work(compiler.compile_root_stylesheet, self.stylesheets)
-
-            # Compile the root document.
-            submit_work(
-                compiler.compile_document_root,
-                self.head_components,
-                html_lang=self.html_lang,
-                html_custom_attrs=self.html_custom_attrs,
-            )
+            _submit_work(compiler.compile_root_stylesheet, self.stylesheets)
 
 
             # Compile the theme.
             # Compile the theme.
-            submit_work(compiler.compile_theme, style=self.style)
-
-            # Compile the contexts.
-            submit_work(compiler.compile_contexts, self.state, self.theme)
+            _submit_work(ExecutorSafeFunctions.compile_theme)
 
 
             # Compile the Tailwind config.
             # Compile the Tailwind config.
             if config.tailwind is not None:
             if config.tailwind is not None:
                 config.tailwind["content"] = config.tailwind.get(
                 config.tailwind["content"] = config.tailwind.get(
                     "content", constants.Tailwind.CONTENT
                     "content", constants.Tailwind.CONTENT
                 )
                 )
-                submit_work(compiler.compile_tailwind, config.tailwind)
+                _submit_work(compiler.compile_tailwind, config.tailwind)
             else:
             else:
-                submit_work(compiler.remove_tailwind_from_postcss)
-
-            # Get imports from AppWrap components.
-            all_imports.update(app_root.get_imports())
+                _submit_work(compiler.remove_tailwind_from_postcss)
 
 
             # Wait for all compilation tasks to complete.
             # Wait for all compilation tasks to complete.
             for future in concurrent.futures.as_completed(result_futures):
             for future in concurrent.futures.as_completed(result_futures):
                 compile_results.append(future.result())
                 compile_results.append(future.result())
 
 
-            # Iterate through all the custom components and add their imports to the all_imports.
-            custom_components_result = custom_components_future.result()
-            compile_results.append(custom_components_result[:2])
-            all_imports.update(custom_components_result[2])
+            # Special case for custom_components, since we need the compiled imports
+            # to install proper frontend packages.
+            (
+                *custom_components_result,
+                custom_components_imports,
+            ) = custom_components_future.result()
+            compile_results.append(custom_components_result)
+            all_imports.update(custom_components_imports)
 
 
-            # Empty the .web pages directory.
-            compiler.purge_web_pages_dir()
+        # Get imports from AppWrap components.
+        all_imports.update(app_root.get_imports())
 
 
-            # Avoid flickering when installing frontend packages
-            progress.stop()
+        progress.advance(task)
 
 
-            # Install frontend packages.
-            self.get_frontend_packages(all_imports)
+        # Empty the .web pages directory.
+        compiler.purge_web_pages_dir()
 
 
-            # Write the pages at the end to trigger the NextJS hot reload only once.
-            write_page_futures = []
-            for output_path, code in compile_results:
-                write_page_futures.append(
-                    thread_pool.submit(compiler_utils.write_page, output_path, code)
-                )
-            for future in concurrent.futures.as_completed(write_page_futures):
-                future.result()
+        progress.advance(task)
+        progress.stop()
+
+        # Install frontend packages.
+        self.get_frontend_packages(all_imports)
+
+        for output_path, code in compile_results:
+            compiler_utils.write_page(output_path, code)
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:

+ 87 - 2
reflex/compiler/compiler.py

@@ -67,7 +67,7 @@ def _compile_theme(theme: dict) -> str:
     return templates.THEME.render(theme=theme)
     return templates.THEME.render(theme=theme)
 
 
 
 
-def _compile_contexts(state: Optional[Type[BaseState]], theme: Component) -> str:
+def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None) -> str:
     """Compile the initial state and contexts.
     """Compile the initial state and contexts.
 
 
     Args:
     Args:
@@ -368,7 +368,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
 
 
 def compile_contexts(
 def compile_contexts(
     state: Optional[Type[BaseState]],
     state: Optional[Type[BaseState]],
-    theme: Component,
+    theme: Component | None,
 ) -> tuple[str, str]:
 ) -> tuple[str, str]:
     """Compile the initial state / context.
     """Compile the initial state / context.
 
 
@@ -494,3 +494,88 @@ def purge_web_pages_dir():
 
 
     # Empty out the web pages directory.
     # Empty out the web pages directory.
     utils.empty_dir(constants.Dirs.WEB_PAGES, keep_files=["_app.js"])
     utils.empty_dir(constants.Dirs.WEB_PAGES, keep_files=["_app.js"])
+
+
+class ExecutorSafeFunctions:
+    """Helper class to allow parallelisation of parts of the compilation process.
+
+    This class (and its class attributes) are available at global scope.
+
+    In a multiprocessing context (like when using a ProcessPoolExecutor), the content of this
+    global class is logically replicated to any FORKED process.
+
+    How it works:
+    * Before the child process is forked, ensure that we stash any input data required by any future
+      function call in the child process.
+    * After the child process is forked, the child process will have a copy of the global class, which
+      includes the previously stashed input data.
+    * Any task submitted to the child process simply needs a way to communicate which input data the
+      requested function call requires.
+
+    Why do we need this? Passing input data directly to child process often not possible because the input data is not picklable.
+    The mechanic described here removes the need to pickle the input data at all.
+
+    Limitations:
+    * This can never support returning unpicklable OUTPUT data.
+    * Any object mutations done by the child process will not propagate back to the parent process (fork goes one way!).
+
+    """
+
+    COMPILE_PAGE_ARGS_BY_ROUTE = {}
+    COMPILE_APP_APP_ROOT: Component | None = None
+    CUSTOM_COMPONENTS: set[CustomComponent] | None = None
+    STYLE: ComponentStyle | None = None
+
+    @classmethod
+    def compile_page(cls, route: str):
+        """Compile a page.
+
+        Args:
+            route: The route of the page to compile.
+
+        Returns:
+            The path and code of the compiled page.
+        """
+        return compile_page(*cls.COMPILE_PAGE_ARGS_BY_ROUTE[route])
+
+    @classmethod
+    def compile_app(cls):
+        """Compile the app.
+
+        Returns:
+            The path and code of the compiled app.
+
+        Raises:
+            ValueError: If the app root is not set.
+        """
+        if cls.COMPILE_APP_APP_ROOT is None:
+            raise ValueError("COMPILE_APP_APP_ROOT should be set")
+        return compile_app(cls.COMPILE_APP_APP_ROOT)
+
+    @classmethod
+    def compile_custom_components(cls):
+        """Compile the custom components.
+
+        Returns:
+            The path and code of the compiled custom components.
+
+        Raises:
+            ValueError: If the custom components are not set.
+        """
+        if cls.CUSTOM_COMPONENTS is None:
+            raise ValueError("CUSTOM_COMPONENTS should be set")
+        return compile_components(cls.CUSTOM_COMPONENTS)
+
+    @classmethod
+    def compile_theme(cls):
+        """Compile the theme.
+
+        Returns:
+            The path and code of the compiled theme.
+
+        Raises:
+            ValueError: If the style is not set.
+        """
+        if cls.STYLE is None:
+            raise ValueError("STYLE should be set")
+        return compile_theme(cls.STYLE)