Prechádzať zdrojové kódy

fix lifespan tasks regression (#5218)

* fix lifespan tasks regression

* fix lifespan issue even when transformer is used

* add_cors to top asgi app

* test_lifespan with FastAPI and api_transformer used

* avoid test warnings when _state_manager is not initialized

* call .app_instance() in the correct directory

* Call the app_instance to get the asgi object during initialization

* revert unnecessary chdir scope

* revert more unnecessary changes

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Thomas Brandého 2 týždňov pred
rodič
commit
bb7e73d76e

+ 7 - 2
reflex/app.py

@@ -488,7 +488,7 @@ class App(MiddlewareMixin, LifespanMixin):
             set_breakpoints(self.style.pop("breakpoints"))
 
         # Set up the API.
-        self._api = Starlette(lifespan=self._run_lifespan_tasks)
+        self._api = Starlette()
         App._add_cors(self._api)
         self._add_default_endpoints()
 
@@ -629,6 +629,7 @@ class App(MiddlewareMixin, LifespanMixin):
 
         if not self._api:
             raise ValueError("The app has not been initialized.")
+
         if self._cached_fastapi_app is not None:
             asgi_app = self._cached_fastapi_app
             asgi_app.mount("", self._api)
@@ -653,7 +654,11 @@ class App(MiddlewareMixin, LifespanMixin):
                     # Transform the asgi app.
                     asgi_app = api_transformer(asgi_app)
 
-        return asgi_app
+        top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks)
+        top_asgi_app.mount("", asgi_app)
+        App._add_cors(top_asgi_app)
+
+        return top_asgi_app
 
     def _add_default_endpoints(self):
         """Add default api endpoints (ping)."""

+ 17 - 12
reflex/testing.py

@@ -45,6 +45,7 @@ from reflex.state import (
 )
 from reflex.utils import console
 from reflex.utils.export import export
+from reflex.utils.types import ASGIApp
 
 try:
     from selenium import webdriver
@@ -110,6 +111,7 @@ class AppHarness:
     app_module_path: Path
     app_module: types.ModuleType | None = None
     app_instance: reflex.App | None = None
+    app_asgi: ASGIApp | None = None
     frontend_process: subprocess.Popen | None = None
     frontend_url: str | None = None
     frontend_output_thread: threading.Thread | None = None
@@ -270,11 +272,14 @@ class AppHarness:
             # Ensure the AppHarness test does not skip State assignment due to running via pytest
             os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
             os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
-            self.app_module = reflex.utils.prerequisites.get_compiled_app(
-                # Do not reload the module for pre-existing apps (only apps generated from source)
-                reload=self.app_source is not None
+            # Ensure we actually compile the app during first initialization.
+            self.app_instance, self.app_module = (
+                reflex.utils.prerequisites.get_and_validate_app(
+                    # Do not reload the module for pre-existing apps (only apps generated from source)
+                    reload=self.app_source is not None
+                )
             )
-        self.app_instance = self.app_module.app
+            self.app_asgi = self.app_instance()
         if self.app_instance and isinstance(
             self.app_instance._state_manager, StateManagerRedis
         ):
@@ -300,10 +305,10 @@ class AppHarness:
         async def _shutdown(*args, **kwargs) -> None:
             # ensure redis is closed before event loop
             if self.app_instance is not None and isinstance(
-                self.app_instance.state_manager, StateManagerRedis
+                self.app_instance._state_manager, StateManagerRedis
             ):
                 with contextlib.suppress(ValueError):
-                    await self.app_instance.state_manager.close()
+                    await self.app_instance._state_manager.close()
 
             # socketio shutdown handler
             if self.app_instance is not None and self.app_instance.sio is not None:
@@ -323,11 +328,11 @@ class AppHarness:
         return _shutdown
 
     def _start_backend(self, port: int = 0):
-        if self.app_instance is None or self.app_instance._api is None:
+        if self.app_asgi is None:
             raise RuntimeError("App was not initialized.")
         self.backend = uvicorn.Server(
             uvicorn.Config(
-                app=self.app_instance._api,
+                app=self.app_asgi,
                 host="127.0.0.1",
                 port=port,
             )
@@ -349,13 +354,13 @@ class AppHarness:
         if (
             self.app_instance is not None
             and isinstance(
-                self.app_instance.state_manager,
+                self.app_instance._state_manager,
                 StateManagerRedis,
             )
             and self.app_instance._state is not None
         ):
             with contextlib.suppress(RuntimeError):
-                await self.app_instance.state_manager.close()
+                await self.app_instance._state_manager.close()
             self.app_instance._state_manager = StateManagerRedis.create(
                 state=self.app_instance._state,
             )
@@ -959,12 +964,12 @@ class AppHarnessProd(AppHarness):
             raise RuntimeError("Frontend did not start")
 
     def _start_backend(self):
-        if self.app_instance is None:
+        if self.app_asgi is None:
             raise RuntimeError("App was not initialized.")
         environment.REFLEX_SKIP_COMPILE.set(True)
         self.backend = uvicorn.Server(
             uvicorn.Config(
-                app=self.app_instance,
+                app=self.app_asgi,
                 host="127.0.0.1",
                 port=0,
                 workers=reflex.utils.processes.get_num_workers(),

+ 1 - 2
tests/integration/test_connection_banner.py

@@ -1,6 +1,5 @@
 """Test case for displaying the connection banner when the websocket drops."""
 
-import functools
 from collections.abc import Generator
 
 import pytest
@@ -77,7 +76,7 @@ def connection_banner(
 
     with AppHarness.create(
         root=tmp_path,
-        app_source=functools.partial(ConnectionBanner),
+        app_source=ConnectionBanner,
         app_name=(
             "connection_banner_reflex_cloud"
             if simulate_compile_context == constants.CompileContext.DEPLOY

+ 57 - 6
tests/integration/test_lifespan.py

@@ -1,5 +1,6 @@
 """Test cases for the Starlette lifespan integration."""
 
+import functools
 from collections.abc import Generator
 
 import pytest
@@ -10,8 +11,15 @@ from reflex.testing import AppHarness
 from .utils import SessionStorage
 
 
-def LifespanApp():
-    """App with lifespan tasks and context."""
+def LifespanApp(
+    mount_cached_fastapi: bool = False, mount_api_transformer: bool = False
+) -> None:
+    """App with lifespan tasks and context.
+
+    Args:
+        mount_cached_fastapi: Whether to mount the cached FastAPI app.
+        mount_api_transformer: Whether to mount the API transformer.
+    """
     import asyncio
     from contextlib import asynccontextmanager
 
@@ -72,25 +80,68 @@ def LifespanApp():
             ),
         )
 
-    app = rx.App()
+    from fastapi import FastAPI
+
+    app = rx.App(api_transformer=FastAPI() if mount_api_transformer else None)
+
+    if mount_cached_fastapi:
+        assert app.api is not None
+
     app.register_lifespan_task(lifespan_task)
     app.register_lifespan_task(lifespan_context, inc=2)
     app.add_page(index)
 
 
+@pytest.fixture(
+    params=[False, True], ids=["no_api_transformer", "mount_api_transformer"]
+)
+def mount_api_transformer(request: pytest.FixtureRequest) -> bool:
+    """Whether to use api_transformer in the app.
+
+    Args:
+        request: pytest fixture request object
+
+    Returns:
+        bool: Whether to use api_transformer
+    """
+    return request.param
+
+
+@pytest.fixture(params=[False, True], ids=["no_fastapi", "mount_cached_fastapi"])
+def mount_cached_fastapi(request: pytest.FixtureRequest) -> bool:
+    """Whether to use cached FastAPI in the app (app.api).
+
+    Args:
+        request: pytest fixture request object
+
+    Returns:
+        Whether to use cached FastAPI
+    """
+    return request.param
+
+
 @pytest.fixture()
-def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
+def lifespan_app(
+    tmp_path, mount_api_transformer: bool, mount_cached_fastapi: bool
+) -> Generator[AppHarness, None, None]:
     """Start LifespanApp app at tmp_path via AppHarness.
 
     Args:
         tmp_path: pytest tmp_path fixture
+        mount_api_transformer: Whether to mount the API transformer.
+        mount_cached_fastapi: Whether to mount the cached FastAPI app.
 
     Yields:
         running AppHarness instance
     """
     with AppHarness.create(
         root=tmp_path,
-        app_source=LifespanApp,
+        app_source=functools.partial(
+            LifespanApp,
+            mount_cached_fastapi=mount_cached_fastapi,
+            mount_api_transformer=mount_api_transformer,
+        ),
+        app_name=f"lifespanapp_fastapi{mount_cached_fastapi}_transformer{mount_api_transformer}",
     ) as harness:
         yield harness
 
@@ -112,7 +163,7 @@ async def test_lifespan(lifespan_app: AppHarness):
     context_global = driver.find_element(By.ID, "context_global")
     task_global = driver.find_element(By.ID, "task_global")
 
-    assert context_global.text == "2"
+    assert lifespan_app.poll_for_content(context_global, exp_not_equal="0") == "2"
     assert lifespan_app.app_module.lifespan_context_global == 2
 
     original_task_global_text = task_global.text