Browse Source

delay api mounting until app finishes compile (#5184)

* delay api mounting until app finishes compile

* mock _compile

* add cors

* add cors to asgi app

* dang it darglint

* refactor code to make it more readable thanks to @masenf
Khaleel Al-Adhami 3 weeks ago
parent
commit
ab019970c1
2 changed files with 22 additions and 22 deletions
  1. 21 22
      reflex/app.py
  2. 1 0
      tests/units/test_app.py

+ 21 - 22
reflex/app.py

@@ -489,7 +489,7 @@ class App(MiddlewareMixin, LifespanMixin):
 
         # Set up the API.
         self._api = Starlette(lifespan=self._run_lifespan_tasks)
-        self._add_cors()
+        App._add_cors(self._api)
         self._add_default_endpoints()
 
         for clz in App.__mro__:
@@ -613,19 +613,6 @@ class App(MiddlewareMixin, LifespanMixin):
         Returns:
             The backend api.
         """
-        if self._cached_fastapi_app is not None:
-            asgi_app = self._cached_fastapi_app
-
-            if not asgi_app or not self._api:
-                raise ValueError("The app has not been initialized.")
-
-            asgi_app.mount("", self._api)
-        else:
-            asgi_app = self._api
-
-            if not asgi_app:
-                raise ValueError("The app has not been initialized.")
-
         # For py3.9 compatibility when redis is used, we MUST add any decorator pages
         # before compiling the app in a thread to avoid event loop error (REF-2172).
         self._apply_decorated_pages()
@@ -637,9 +624,17 @@ class App(MiddlewareMixin, LifespanMixin):
             # Force background compile errors to print eagerly
             lambda f: f.result()
         )
-        # Wait for the compile to finish in prod mode to ensure all optional endpoints are mounted.
-        if is_prod_mode():
-            compile_future.result()
+        # Wait for the compile to finish to ensure all optional endpoints are mounted.
+        compile_future.result()
+
+        if not self._api:
+            raise ValueError("The app has not been initialized.")
+        if self._cached_fastapi_app is not None:
+            asgi_app = self._cached_fastapi_app
+            asgi_app.mount("", self._api)
+            App._add_cors(asgi_app)
+        else:
+            asgi_app = self._api
 
         if self.api_transformer is not None:
             api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = (
@@ -651,6 +646,7 @@ class App(MiddlewareMixin, LifespanMixin):
             for api_transformer in api_transformers:
                 if isinstance(api_transformer, Starlette):
                     # Mount the api to the fastapi app.
+                    App._add_cors(api_transformer)
                     api_transformer.mount("", asgi_app)
                     asgi_app = api_transformer
                 else:
@@ -709,11 +705,14 @@ class App(MiddlewareMixin, LifespanMixin):
         if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
             self.add_all_routes_endpoint()
 
-    def _add_cors(self):
-        """Add CORS middleware to the app."""
-        if not self._api:
-            return
-        self._api.add_middleware(
+    @staticmethod
+    def _add_cors(api: Starlette):
+        """Add CORS middleware to the app.
+
+        Args:
+            api: The Starlette app to add CORS middleware to.
+        """
+        api.add_middleware(
             cors.CORSMiddleware,
             allow_credentials=True,
             allow_methods=["*"],

+ 1 - 0
tests/units/test_app.py

@@ -1502,6 +1502,7 @@ def test_raise_on_state():
 def test_call_app():
     """Test that the app can be called."""
     app = App()
+    app._compile = unittest.mock.Mock()
     api = app()
     assert isinstance(api, Starlette)