Forráskód Böngészése

starlette over fastapi (#5069)

* reintroduce python multipart for formdata

* starlette over fastapi

* fix the tests

* simplify json response

* use json response for all of these guys

* add transformer

* vendor types for future compatibility

* pre-commit

* we can actually just deprecate this guy

* update uv lock

* use api_transformer stuff

* fix the tests

* it's ruff out there
Khaleel Al-Adhami 3 hete
szülő
commit
71c1de681f

+ 125 - 44
reflex/app.py

@@ -13,22 +13,26 @@ import io
 import json
 import sys
 import traceback
-from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping
+from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
 from datetime import datetime
 from pathlib import Path
 from timeit import default_timer as timer
 from types import SimpleNamespace
 from typing import TYPE_CHECKING, Any, BinaryIO, get_args, get_type_hints
 
-from fastapi import FastAPI, HTTPException, Request
-from fastapi import UploadFile as FastAPIUploadFile
-from fastapi.middleware import cors
-from fastapi.responses import JSONResponse, StreamingResponse
-from fastapi.staticfiles import StaticFiles
+from fastapi import FastAPI
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
-from socketio import ASGIApp, AsyncNamespace, AsyncServer
+from socketio import ASGIApp as EngineIOApp
+from socketio import AsyncNamespace, AsyncServer
+from starlette.applications import Starlette
 from starlette.datastructures import Headers
 from starlette.datastructures import UploadFile as StarletteUploadFile
+from starlette.exceptions import HTTPException
+from starlette.middleware import cors
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response, StreamingResponse
+from starlette.staticfiles import StaticFiles
+from typing_extensions import deprecated
 
 from reflex import constants
 from reflex.admin import AdminDash
@@ -102,6 +106,7 @@ from reflex.utils import (
 )
 from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env
 from reflex.utils.imports import ImportVar
+from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send
 
 if TYPE_CHECKING:
     from reflex.vars import Var
@@ -389,7 +394,7 @@ class App(MiddlewareMixin, LifespanMixin):
     _stateful_pages: dict[str, None] = dataclasses.field(default_factory=dict)
 
     # The backend API object.
-    _api: FastAPI | None = None
+    _api: Starlette | None = None
 
     # The state class to use for the app.
     _state: type[BaseState] | None = None
@@ -424,14 +429,34 @@ class App(MiddlewareMixin, LifespanMixin):
     # Put the toast provider in the app wrap.
     toaster: Component | None = dataclasses.field(default_factory=toast.provider)
 
+    # Transform the ASGI app before running it.
+    api_transformer: (
+        Sequence[Callable[[ASGIApp], ASGIApp] | Starlette]
+        | Callable[[ASGIApp], ASGIApp]
+        | Starlette
+        | None
+    ) = None
+
+    # FastAPI app for compatibility with FastAPI.
+    _cached_fastapi_app: FastAPI | None = None
+
     @property
-    def api(self) -> FastAPI | None:
+    @deprecated("Use `api_transformer=your_fastapi_app` instead.")
+    def api(self) -> FastAPI:
         """Get the backend api.
 
         Returns:
             The backend api.
         """
-        return self._api
+        if self._cached_fastapi_app is None:
+            self._cached_fastapi_app = FastAPI()
+        console.deprecate(
+            feature_name="App.api",
+            reason="Set `api_transformer=your_fastapi_app` instead.",
+            deprecation_version="0.7.9",
+            removal_version="0.8.0",
+        )
+        return self._cached_fastapi_app
 
     @property
     def event_namespace(self) -> EventNamespace | None:
@@ -463,7 +488,7 @@ class App(MiddlewareMixin, LifespanMixin):
             set_breakpoints(self.style.pop("breakpoints"))
 
         # Set up the API.
-        self._api = FastAPI(lifespan=self._run_lifespan_tasks)
+        self._api = Starlette(lifespan=self._run_lifespan_tasks)
         self._add_cors()
         self._add_default_endpoints()
 
@@ -529,7 +554,7 @@ class App(MiddlewareMixin, LifespanMixin):
             )
 
         # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
-        socket_app = ASGIApp(self.sio, socketio_path="")
+        socket_app = EngineIOApp(self.sio, socketio_path="")
         namespace = config.get_event_namespace()
 
         # Create the event namespace and attach the main app. Not related to any paths.
@@ -538,18 +563,16 @@ class App(MiddlewareMixin, LifespanMixin):
         # Register the event namespace with the socket.
         self.sio.register_namespace(self.event_namespace)
         # Mount the socket app with the API.
-        if self.api:
+        if self._api:
 
             class HeaderMiddleware:
                 def __init__(self, app: ASGIApp):
                     self.app = app
 
-                async def __call__(
-                    self, scope: MutableMapping[str, Any], receive: Any, send: Callable
-                ):
+                async def __call__(self, scope: Scope, receive: Receive, send: Send):
                     original_send = send
 
-                    async def modified_send(message: dict):
+                    async def modified_send(message: Message):
                         if message["type"] == "websocket.accept":
                             if scope.get("subprotocols"):
                                 # The following *does* say "subprotocol" instead of "subprotocols", intentionally.
@@ -568,7 +591,7 @@ class App(MiddlewareMixin, LifespanMixin):
                     return await self.app(scope, receive, modified_send)
 
             socket_app_with_headers = HeaderMiddleware(socket_app)
-            self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
+            self._api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
 
         # Check the exception handlers
         self._validate_exception_handlers()
@@ -581,7 +604,7 @@ class App(MiddlewareMixin, LifespanMixin):
         """
         return f"<App state={self._state.__name__ if self._state else None}>"
 
-    def __call__(self) -> FastAPI:
+    def __call__(self) -> ASGIApp:
         """Run the backend api instance.
 
         Raises:
@@ -590,8 +613,18 @@ class App(MiddlewareMixin, LifespanMixin):
         Returns:
             The backend api.
         """
-        if not self.api:
-            raise ValueError("The app has not been initialized.")
+        if self._cached_fastapi_app is not None:
+            asgi_app = self._cached_fastapi_app
+
+            if not asgi_app or not self._api:
+                raise ValueError("The app has not been initialized.")
+
+            asgi_app.mount("", self._api)
+        else:
+            asgi_app = self._api
+
+            if not asgi_app:
+                raise ValueError("The app has not been initialized.")
 
         # For py3.9 compatibility when redis is used, we MUST add any decorator pages
         # before compiling the app in a thread to avoid event loop error (REF-2172).
@@ -608,30 +641,58 @@ class App(MiddlewareMixin, LifespanMixin):
         if is_prod_mode():
             compile_future.result()
 
-        return self.api
+        if self.api_transformer is not None:
+            api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = (
+                [self.api_transformer]
+                if not isinstance(self.api_transformer, Sequence)
+                else self.api_transformer
+            )
+
+            for api_transformer in api_transformers:
+                if isinstance(api_transformer, Starlette):
+                    # Mount the api to the fastapi app.
+                    api_transformer.mount("", asgi_app)
+                    asgi_app = api_transformer
+                else:
+                    # Transform the asgi app.
+                    asgi_app = api_transformer(asgi_app)
+
+        return asgi_app
 
     def _add_default_endpoints(self):
         """Add default api endpoints (ping)."""
         # To test the server.
-        if not self.api:
+        if not self._api:
             return
 
-        self.api.get(str(constants.Endpoint.PING))(ping)
-        self.api.get(str(constants.Endpoint.HEALTH))(health)
+        self._api.add_route(
+            str(constants.Endpoint.PING),
+            ping,
+            methods=["GET"],
+        )
+        self._api.add_route(
+            str(constants.Endpoint.HEALTH),
+            health,
+            methods=["GET"],
+        )
 
     def _add_optional_endpoints(self):
         """Add optional api endpoints (_upload)."""
-        if not self.api:
+        if not self._api:
             return
         upload_is_used_marker = (
             prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED
         )
         if Upload.is_used or upload_is_used_marker.exists():
             # To upload files.
-            self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
+            self._api.add_route(
+                str(constants.Endpoint.UPLOAD),
+                upload(self),
+                methods=["POST"],
+            )
 
             # To access uploaded files.
-            self.api.mount(
+            self._api.mount(
                 str(constants.Endpoint.UPLOAD),
                 StaticFiles(directory=get_upload_dir()),
                 name="uploaded_files",
@@ -640,17 +701,19 @@ class App(MiddlewareMixin, LifespanMixin):
             upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
             upload_is_used_marker.touch()
         if codespaces.is_running_in_codespaces():
-            self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
-                codespaces.auth_codespace
+            self._api.add_route(
+                str(constants.Endpoint.AUTH_CODESPACE),
+                codespaces.auth_codespace,
+                methods=["GET"],
             )
         if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
             self.add_all_routes_endpoint()
 
     def _add_cors(self):
         """Add CORS middleware to the app."""
-        if not self.api:
+        if not self._api:
             return
-        self.api.add_middleware(
+        self._api.add_middleware(
             cors.CORSMiddleware,
             allow_credentials=True,
             allow_methods=["*"],
@@ -915,7 +978,7 @@ class App(MiddlewareMixin, LifespanMixin):
             return
 
         # Get the admin dash.
-        if not self.api:
+        if not self._api:
             return
 
         admin_dash = self.admin_dash
@@ -936,7 +999,7 @@ class App(MiddlewareMixin, LifespanMixin):
                 view = admin_dash.view_overrides.get(model, ModelView)
                 admin.add_view(view(model))
 
-            admin.mount_to(self.api)
+            admin.mount_to(self._api)
 
     def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]):
         """Gets the frontend packages to be installed and filters out the unnecessary ones.
@@ -1427,12 +1490,15 @@ class App(MiddlewareMixin, LifespanMixin):
 
     def add_all_routes_endpoint(self):
         """Add an endpoint to the app that returns all the routes."""
-        if not self.api:
+        if not self._api:
             return
 
-        @self.api.get(str(constants.Endpoint.ALL_ROUTES))
-        async def all_routes():
-            return list(self._unevaluated_pages.keys())
+        async def all_routes(_request: Request) -> Response:
+            return JSONResponse(list(self._unevaluated_pages.keys()))
+
+        self._api.add_route(
+            str(constants.Endpoint.ALL_ROUTES), all_routes, methods=["GET"]
+        )
 
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
@@ -1687,18 +1753,24 @@ async def process(
         raise
 
 
-async def ping() -> str:
+async def ping(_request: Request) -> Response:
     """Test API endpoint.
 
+    Args:
+        _request: The Starlette request object.
+
     Returns:
         The response.
     """
-    return "pong"
+    return JSONResponse("pong")
 
 
-async def health() -> JSONResponse:
+async def health(_request: Request) -> JSONResponse:
     """Health check endpoint to assess the status of the database and Redis services.
 
+    Args:
+        _request: The Starlette request object.
+
     Returns:
         JSONResponse: A JSON object with the health status:
             - "status" (bool): Overall health, True if all checks pass.
@@ -1740,12 +1812,11 @@ def upload(app: App):
         The upload function.
     """
 
-    async def upload_file(request: Request, files: list[FastAPIUploadFile]):
+    async def upload_file(request: Request):
         """Upload a file.
 
         Args:
-            request: The FastAPI request object.
-            files: The file(s) to upload.
+            request: The Starlette request object.
 
         Returns:
             StreamingResponse yielding newline-delimited JSON of StateUpdate
@@ -1758,6 +1829,12 @@ def upload(app: App):
         """
         from reflex.utils.exceptions import UploadTypeError, UploadValueError
 
+        # Get the files from the request.
+        files = await request.form()
+        files = files.getlist("files")
+        if not files:
+            raise UploadValueError("No files were uploaded.")
+
         token = request.headers.get("reflex-client-token")
         handler = request.headers.get("reflex-event-handler")
 
@@ -1810,6 +1887,10 @@ def upload(app: App):
         # event is handled.
         file_copies = []
         for file in files:
+            if not isinstance(file, StarletteUploadFile):
+                raise UploadValueError(
+                    "Uploaded file is not an UploadFile." + str(file)
+                )
             content_copy = io.BytesIO()
             content_copy.write(await file.read())
             content_copy.seek(0)

+ 2 - 2
reflex/app_mixins/lifespan.py

@@ -9,7 +9,7 @@ import functools
 import inspect
 from collections.abc import Callable, Coroutine
 
-from fastapi import FastAPI
+from starlette.applications import Starlette
 
 from reflex.utils import console
 from reflex.utils.exceptions import InvalidLifespanTaskTypeError
@@ -27,7 +27,7 @@ class LifespanMixin(AppMixin):
     )
 
     @contextlib.asynccontextmanager
-    async def _run_lifespan_tasks(self, app: FastAPI):
+    async def _run_lifespan_tasks(self, app: Starlette):
         running_tasks = []
         try:
             async with contextlib.AsyncExitStack() as stack:

+ 2 - 2
reflex/testing.py

@@ -322,11 +322,11 @@ class AppHarness:
         return _shutdown
 
     def _start_backend(self, port: int = 0):
-        if self.app_instance is None or self.app_instance.api is None:
+        if self.app_instance is None or self.app_instance._api is None:
             raise RuntimeError("App was not initialized.")
         self.backend = uvicorn.Server(
             uvicorn.Config(
-                app=self.app_instance.api,
+                app=self.app_instance._api,
                 host="127.0.0.1",
                 port=port,
             )

+ 6 - 2
reflex/utils/codespaces.py

@@ -4,7 +4,8 @@ from __future__ import annotations
 
 import os
 
-from fastapi.responses import HTMLResponse
+from starlette.requests import Request
+from starlette.responses import HTMLResponse
 
 from reflex.components.base.script import Script
 from reflex.components.component import Component
@@ -74,9 +75,12 @@ def codespaces_auto_redirect() -> list[Component]:
     return []
 
 
-async def auth_codespace() -> HTMLResponse:
+async def auth_codespace(_request: Request) -> HTMLResponse:
     """Page automatically redirecting back to the app after authenticating a codespace port forward.
 
+    Args:
+        _request: The request object.
+
     Returns:
         An HTML response with an embedded script to redirect back to the app.
     """

+ 9 - 0
reflex/utils/types.py

@@ -11,11 +11,13 @@ from types import GenericAlias
 from typing import (  # noqa: UP035
     TYPE_CHECKING,
     Any,
+    Awaitable,
     ClassVar,
     Dict,
     ForwardRef,
     List,
     Literal,
+    MutableMapping,
     NoReturn,
     Tuple,
     Union,
@@ -73,6 +75,13 @@ if TYPE_CHECKING:
 else:
     ArgsSpec = Callable[..., list[Any]]
 
+Scope = MutableMapping[str, Any]
+Message = MutableMapping[str, Any]
+
+Receive = Callable[[], Awaitable[Message]]
+Send = Callable[[Message], Awaitable[None]]
+
+ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
 
 PrimitiveToAnnotation = {
     list: List,  # noqa: UP006

+ 1 - 1
tests/integration/test_lifespan.py

@@ -1,4 +1,4 @@
-"""Test cases for the FastAPI lifespan integration."""
+"""Test cases for the Starlette lifespan integration."""
 
 from collections.abc import Generator
 

+ 48 - 7
tests/units/test_app.py

@@ -14,8 +14,9 @@ from unittest.mock import AsyncMock
 
 import pytest
 import sqlmodel
-from fastapi import FastAPI, UploadFile
 from pytest_mock import MockerFixture
+from starlette.applications import Starlette
+from starlette.datastructures import UploadFile
 from starlette_admin.auth import AuthProvider
 from starlette_admin.contrib.sqla.admin import Admin
 from starlette_admin.contrib.sqla.view import ModelView
@@ -813,8 +814,22 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
         filename="image2.jpg",
         file=bio,
     )
+
+    async def form():
+        files_mock = unittest.mock.Mock()
+
+        def getlist(key: str):
+            assert key == "files"
+            return [file1, file2]
+
+        files_mock.getlist = getlist
+
+        return files_mock
+
+    request_mock.form = form
+
     upload_fn = upload(app)
-    streaming_response = await upload_fn(request_mock, [file1, file2])  # pyright: ignore [reportFunctionMemberAccess]
+    streaming_response = await upload_fn(request_mock)
     async for state_update in streaming_response.body_iterator:
         assert (
             state_update
@@ -853,10 +868,23 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
         "reflex-client-token": token,
         "reflex-event-handler": f"{state.get_full_name()}.handle_upload2",
     }
-    file_mock = unittest.mock.Mock(filename="image1.jpg")
+
+    async def form():
+        files_mock = unittest.mock.Mock()
+
+        def getlist(key: str):
+            assert key == "files"
+            return [unittest.mock.Mock(filename="image1.jpg")]
+
+        files_mock.getlist = getlist
+
+        return files_mock
+
+    request_mock.form = form
+
     fn = upload(app)
     with pytest.raises(ValueError) as err:
-        await fn(request_mock, [file_mock])
+        await fn(request_mock)
     assert (
         err.value.args[0]
         == f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as list[rx.UploadFile]"
@@ -887,10 +915,23 @@ async def test_upload_file_background(state, tmp_path, token):
         "reflex-client-token": token,
         "reflex-event-handler": f"{state.get_full_name()}.bg_upload",
     }
-    file_mock = unittest.mock.Mock(filename="image1.jpg")
+
+    async def form():
+        files_mock = unittest.mock.Mock()
+
+        def getlist(key: str):
+            assert key == "files"
+            return [unittest.mock.Mock(filename="image1.jpg")]
+
+        files_mock.getlist = getlist
+
+        return files_mock
+
+    request_mock.form = form
+
     fn = upload(app)
     with pytest.raises(TypeError) as err:
-        await fn(request_mock, [file_mock])
+        await fn(request_mock)
     assert (
         err.value.args[0]
         == f"@rx.event(background=True) is not supported for upload handler `{state.get_full_name()}.bg_upload`."
@@ -1462,7 +1503,7 @@ def test_call_app():
     """Test that the app can be called."""
     app = App()
     api = app()
-    assert isinstance(api, FastAPI)
+    assert isinstance(api, Starlette)
 
 
 def test_app_with_optional_endpoints():

+ 3 - 1
tests/units/test_health_endpoint.py

@@ -119,8 +119,10 @@ async def test_health(
         return_value={"redis": redis_status},
     )
 
+    request = Mock()
+
     # Call the async health function
-    response = await health()
+    response = await health(request)
 
     # Verify the response content and status code
     assert response.status_code == expected_code