瀏覽代碼

[ENG-4134]Allow specifying custom app module in rxconfig (#4556)

* Allow custom app module in rxconfig

* what was that pyscopg mess?

* fix another mess

* get this working with relative imports and hot reload

* typing to named tuple

* minor refactor

* revert redis knobs positions

* fix pyright except 1

* fix pyright hopefully

* use the resolved module path

* testing workflow

* move nba-proxy job to counter job

* just cast the type

* fix tests for python 3.9

* darglint

* CR Suggestions for #4556 (#4644)

* reload_dirs: search up from app_module for last directory containing __init__

* Change custom app_module to use an import string

* preserve sys.path entries added while loading rxconfig.py

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Elijah Ahianyo 3 月之前
父節點
當前提交
268effe62e
共有 7 個文件被更改,包括 132 次插入30 次删除
  1. 21 1
      .github/workflows/integration_tests.yml
  2. 3 4
      reflex/app_module_for_backend.py
  3. 27 3
      reflex/config.py
  4. 1 1
      reflex/event.py
  5. 15 14
      reflex/state.py
  6. 24 2
      reflex/utils/exec.py
  7. 41 5
      reflex/utils/prerequisites.py

+ 21 - 1
.github/workflows/integration_tests.yml

@@ -33,7 +33,7 @@ env:
   PR_TITLE: ${{ github.event.pull_request.title }}
   PR_TITLE: ${{ github.event.pull_request.title }}
 
 
 jobs:
 jobs:
-  example-counter:
+  example-counter-and-nba-proxy:
     env:
     env:
       OUTPUT_FILE: import_benchmark.json
       OUTPUT_FILE: import_benchmark.json
     timeout-minutes: 30
     timeout-minutes: 30
@@ -119,6 +119,26 @@ jobs:
           --benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}"
           --benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}"
           --branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}"
           --branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}"
           --app-name "counter"
           --app-name "counter"
+      - name: Install requirements for nba proxy example
+        working-directory: ./reflex-examples/nba-proxy
+        run: |
+          poetry run uv pip install -r requirements.txt
+      - name: Install additional dependencies for DB access
+        run: poetry run uv pip install psycopg
+      - name: Check export --backend-only before init for nba-proxy example
+        working-directory: ./reflex-examples/nba-proxy
+        run: |
+          poetry run reflex export --backend-only
+      - name: Init Website for nba-proxy example
+        working-directory: ./reflex-examples/nba-proxy
+        run: |
+          poetry run reflex init --loglevel debug
+      - name: Run Website and Check for errors
+        run: |
+          # Check that npm is home
+          npm -v
+          poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev
+
 
 
   reflex-web:
   reflex-web:
     strategy:
     strategy:

+ 3 - 4
reflex/app_module_for_backend.py

@@ -7,14 +7,13 @@ from concurrent.futures import ThreadPoolExecutor
 from reflex import constants
 from reflex import constants
 from reflex.utils import telemetry
 from reflex.utils import telemetry
 from reflex.utils.exec import is_prod_mode
 from reflex.utils.exec import is_prod_mode
-from reflex.utils.prerequisites import get_app
+from reflex.utils.prerequisites import get_and_validate_app
 
 
 if constants.CompileVars.APP != "app":
 if constants.CompileVars.APP != "app":
     raise AssertionError("unexpected variable name for 'app'")
     raise AssertionError("unexpected variable name for 'app'")
 
 
 telemetry.send("compile")
 telemetry.send("compile")
-app_module = get_app(reload=False)
-app = getattr(app_module, constants.CompileVars.APP)
+app, app_module = get_and_validate_app(reload=False)
 # For py3.9 compatibility when redis is used, we MUST add any decorator pages
 # For py3.9 compatibility when redis is used, we MUST add any decorator pages
 # before compiling the app in a thread to avoid event loop error (REF-2172).
 # before compiling the app in a thread to avoid event loop error (REF-2172).
 app._apply_decorated_pages()
 app._apply_decorated_pages()
@@ -30,7 +29,7 @@ if is_prod_mode():
 # ensure only "app" is exposed.
 # ensure only "app" is exposed.
 del app_module
 del app_module
 del compile_future
 del compile_future
-del get_app
+del get_and_validate_app
 del is_prod_mode
 del is_prod_mode
 del telemetry
 del telemetry
 del constants
 del constants

+ 27 - 3
reflex/config.py

@@ -12,6 +12,7 @@ import threading
 import urllib.parse
 import urllib.parse
 from importlib.util import find_spec
 from importlib.util import find_spec
 from pathlib import Path
 from pathlib import Path
+from types import ModuleType
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
     Any,
     Any,
@@ -607,6 +608,9 @@ class Config(Base):
     # The name of the app (should match the name of the app directory).
     # The name of the app (should match the name of the app directory).
     app_name: str
     app_name: str
 
 
+    # The path to the app module.
+    app_module_import: Optional[str] = None
+
     # The log level to use.
     # The log level to use.
     loglevel: constants.LogLevel = constants.LogLevel.DEFAULT
     loglevel: constants.LogLevel = constants.LogLevel.DEFAULT
 
 
@@ -729,6 +733,19 @@ class Config(Base):
                 "REDIS_URL is required when using the redis state manager."
                 "REDIS_URL is required when using the redis state manager."
             )
             )
 
 
+    @property
+    def app_module(self) -> ModuleType | None:
+        """Return the app module if `app_module_import` is set.
+
+        Returns:
+            The app module.
+        """
+        return (
+            importlib.import_module(self.app_module_import)
+            if self.app_module_import
+            else None
+        )
+
     @property
     @property
     def module(self) -> str:
     def module(self) -> str:
         """Get the module name of the app.
         """Get the module name of the app.
@@ -736,6 +753,8 @@ class Config(Base):
         Returns:
         Returns:
             The module name.
             The module name.
         """
         """
+        if self.app_module is not None:
+            return self.app_module.__name__
         return ".".join([self.app_name, self.app_name])
         return ".".join([self.app_name, self.app_name])
 
 
     def update_from_env(self) -> dict[str, Any]:
     def update_from_env(self) -> dict[str, Any]:
@@ -874,7 +893,7 @@ def get_config(reload: bool = False) -> Config:
             return cached_rxconfig.config
             return cached_rxconfig.config
 
 
     with _config_lock:
     with _config_lock:
-        sys_path = sys.path.copy()
+        orig_sys_path = sys.path.copy()
         sys.path.clear()
         sys.path.clear()
         sys.path.append(str(Path.cwd()))
         sys.path.append(str(Path.cwd()))
         try:
         try:
@@ -882,9 +901,14 @@ def get_config(reload: bool = False) -> Config:
             return _get_config()
             return _get_config()
         except Exception:
         except Exception:
             # If the module import fails, try to import with the original sys.path.
             # If the module import fails, try to import with the original sys.path.
-            sys.path.extend(sys_path)
+            sys.path.extend(orig_sys_path)
             return _get_config()
             return _get_config()
         finally:
         finally:
+            # Find any entries added to sys.path by rxconfig.py itself.
+            extra_paths = [
+                p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd())
+            ]
             # Restore the original sys.path.
             # Restore the original sys.path.
             sys.path.clear()
             sys.path.clear()
-            sys.path.extend(sys_path)
+            sys.path.extend(extra_paths)
+            sys.path.extend(orig_sys_path)

+ 1 - 1
reflex/event.py

@@ -1591,7 +1591,7 @@ def get_handler_args(
 
 
 
 
 def fix_events(
 def fix_events(
-    events: list[EventHandler | EventSpec] | None,
+    events: list[EventSpec | EventHandler] | None,
     token: str,
     token: str,
     router_data: dict[str, Any] | None = None,
     router_data: dict[str, Any] | None = None,
 ) -> list[Event]:
 ) -> list[Event]:

+ 15 - 14
reflex/state.py

@@ -1776,9 +1776,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         except Exception as ex:
         except Exception as ex:
             state._clean()
             state._clean()
 
 
-            app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
-
-            event_specs = app_instance.backend_exception_handler(ex)
+            event_specs = (
+                prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
+            )
 
 
             if event_specs is None:
             if event_specs is None:
                 return StateUpdate()
                 return StateUpdate()
@@ -1888,9 +1888,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         except Exception as ex:
         except Exception as ex:
             telemetry.send_error(ex, context="backend")
             telemetry.send_error(ex, context="backend")
 
 
-            app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
-
-            event_specs = app_instance.backend_exception_handler(ex)
+            event_specs = (
+                prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
+            )
 
 
             yield state._as_state_update(
             yield state._as_state_update(
                 handler,
                 handler,
@@ -2403,8 +2403,9 @@ class FrontendEventExceptionState(State):
             component_stack: The stack trace of the component where the exception occurred.
             component_stack: The stack trace of the component where the exception occurred.
 
 
         """
         """
-        app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
-        app_instance.frontend_exception_handler(Exception(stack))
+        prerequisites.get_and_validate_app().app.frontend_exception_handler(
+            Exception(stack)
+        )
 
 
 
 
 class UpdateVarsInternalState(State):
 class UpdateVarsInternalState(State):
@@ -2442,15 +2443,16 @@ class OnLoadInternalState(State):
             The list of events to queue for on load handling.
             The list of events to queue for on load handling.
         """
         """
         # Do not app._compile()!  It should be already compiled by now.
         # Do not app._compile()!  It should be already compiled by now.
-        app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
-        load_events = app.get_load_events(self.router.page.path)
+        load_events = prerequisites.get_and_validate_app().app.get_load_events(
+            self.router.page.path
+        )
         if not load_events:
         if not load_events:
             self.is_hydrated = True
             self.is_hydrated = True
             return  # Fast path for navigation with no on_load events defined.
             return  # Fast path for navigation with no on_load events defined.
         self.is_hydrated = False
         self.is_hydrated = False
         return [
         return [
             *fix_events(
             *fix_events(
-                load_events,
+                cast(list[Union[EventSpec, EventHandler]], load_events),
                 self.router.session.client_token,
                 self.router.session.client_token,
                 router_data=self.router_data,
                 router_data=self.router_data,
             ),
             ),
@@ -2609,7 +2611,7 @@ class StateProxy(wrapt.ObjectProxy):
         """
         """
         super().__init__(state_instance)
         super().__init__(state_instance)
         # compile is not relevant to backend logic
         # compile is not relevant to backend logic
-        self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+        self._self_app = prerequisites.get_and_validate_app().app
         self._self_substate_path = tuple(state_instance.get_full_name().split("."))
         self._self_substate_path = tuple(state_instance.get_full_name().split("."))
         self._self_actx = None
         self._self_actx = None
         self._self_mutable = False
         self._self_mutable = False
@@ -3702,8 +3704,7 @@ def get_state_manager() -> StateManager:
     Returns:
     Returns:
         The state manager.
         The state manager.
     """
     """
-    app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
-    return app.state_manager
+    return prerequisites.get_and_validate_app().app.state_manager
 
 
 
 
 class MutableProxy(wrapt.ObjectProxy):
 class MutableProxy(wrapt.ObjectProxy):

+ 24 - 2
reflex/utils/exec.py

@@ -240,6 +240,28 @@ def run_backend(
         run_uvicorn_backend(host, port, loglevel)
         run_uvicorn_backend(host, port, loglevel)
 
 
 
 
+def get_reload_dirs() -> list[str]:
+    """Get the reload directories for the backend.
+
+    Returns:
+        The reload directories for the backend.
+    """
+    config = get_config()
+    reload_dirs = [config.app_name]
+    if config.app_module is not None and config.app_module.__file__:
+        module_path = Path(config.app_module.__file__).resolve().parent
+        while module_path.parent.name:
+            for parent_file in module_path.parent.iterdir():
+                if parent_file == "__init__.py":
+                    # go up a level to find dir without `__init__.py`
+                    module_path = module_path.parent
+                    break
+            else:
+                break
+        reload_dirs.append(str(module_path))
+    return reload_dirs
+
+
 def run_uvicorn_backend(host, port, loglevel: LogLevel):
 def run_uvicorn_backend(host, port, loglevel: LogLevel):
     """Run the backend in development mode using Uvicorn.
     """Run the backend in development mode using Uvicorn.
 
 
@@ -256,7 +278,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel):
         port=port,
         port=port,
         log_level=loglevel.value,
         log_level=loglevel.value,
         reload=True,
         reload=True,
-        reload_dirs=[get_config().app_name],
+        reload_dirs=get_reload_dirs(),
     )
     )
 
 
 
 
@@ -281,7 +303,7 @@ def run_granian_backend(host, port, loglevel: LogLevel):
             interface=Interfaces.ASGI,
             interface=Interfaces.ASGI,
             log_level=LogLevels(loglevel.value),
             log_level=LogLevels(loglevel.value),
             reload=True,
             reload=True,
-            reload_paths=[Path(get_config().app_name)],
+            reload_paths=get_reload_dirs(),
             reload_ignore_dirs=[".web"],
             reload_ignore_dirs=[".web"],
         ).serve()
         ).serve()
     except ImportError:
     except ImportError:

+ 41 - 5
reflex/utils/prerequisites.py

@@ -17,11 +17,12 @@ import stat
 import sys
 import sys
 import tempfile
 import tempfile
 import time
 import time
+import typing
 import zipfile
 import zipfile
 from datetime import datetime
 from datetime import datetime
 from pathlib import Path
 from pathlib import Path
 from types import ModuleType
 from types import ModuleType
-from typing import Callable, List, Optional
+from typing import Callable, List, NamedTuple, Optional
 
 
 import httpx
 import httpx
 import typer
 import typer
@@ -42,9 +43,19 @@ from reflex.utils.exceptions import (
 from reflex.utils.format import format_library_name
 from reflex.utils.format import format_library_name
 from reflex.utils.registry import _get_npm_registry
 from reflex.utils.registry import _get_npm_registry
 
 
+if typing.TYPE_CHECKING:
+    from reflex.app import App
+
 CURRENTLY_INSTALLING_NODE = False
 CURRENTLY_INSTALLING_NODE = False
 
 
 
 
+class AppInfo(NamedTuple):
+    """A tuple containing the app instance and module."""
+
+    app: App
+    module: ModuleType
+
+
 @dataclasses.dataclass(frozen=True)
 @dataclasses.dataclass(frozen=True)
 class Template:
 class Template:
     """A template for a Reflex app."""
     """A template for a Reflex app."""
@@ -291,8 +302,11 @@ def get_app(reload: bool = False) -> ModuleType:
             )
             )
         module = config.module
         module = config.module
         sys.path.insert(0, str(Path.cwd()))
         sys.path.insert(0, str(Path.cwd()))
-        app = __import__(module, fromlist=(constants.CompileVars.APP,))
-
+        app = (
+            __import__(module, fromlist=(constants.CompileVars.APP,))
+            if not config.app_module
+            else config.app_module
+        )
         if reload:
         if reload:
             from reflex.state import reload_state_module
             from reflex.state import reload_state_module
 
 
@@ -308,6 +322,29 @@ def get_app(reload: bool = False) -> ModuleType:
         raise
         raise
 
 
 
 
+def get_and_validate_app(reload: bool = False) -> AppInfo:
+    """Get the app instance based on the default config and validate it.
+
+    Args:
+        reload: Re-import the app module from disk
+
+    Returns:
+        The app instance and the app module.
+
+    Raises:
+        RuntimeError: If the app instance is not an instance of rx.App.
+    """
+    from reflex.app import App
+
+    app_module = get_app(reload=reload)
+    app = getattr(app_module, constants.CompileVars.APP)
+    if not isinstance(app, App):
+        raise RuntimeError(
+            "The app instance in the specified app_module_import in rxconfig must be an instance of rx.App."
+        )
+    return AppInfo(app=app, module=app_module)
+
+
 def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
 def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
     """Get the app module based on the default config after first compiling it.
     """Get the app module based on the default config after first compiling it.
 
 
@@ -318,8 +355,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
     Returns:
     Returns:
         The compiled app based on the default config.
         The compiled app based on the default config.
     """
     """
-    app_module = get_app(reload=reload)
-    app = getattr(app_module, constants.CompileVars.APP)
+    app, app_module = get_and_validate_app(reload=reload)
     # For py3.9 compatibility when redis is used, we MUST add any decorator pages
     # For py3.9 compatibility when redis is used, we MUST add any decorator pages
     # before compiling the app in a thread to avoid event loop error (REF-2172).
     # before compiling the app in a thread to avoid event loop error (REF-2172).
     app._apply_decorated_pages()
     app._apply_decorated_pages()