|
@@ -13,22 +13,26 @@ import io
|
|
|
import json
|
|
|
import sys
|
|
|
import traceback
|
|
|
-from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping
|
|
|
+from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
|
|
|
from datetime import datetime
|
|
|
from pathlib import Path
|
|
|
from timeit import default_timer as timer
|
|
|
from types import SimpleNamespace
|
|
|
from typing import TYPE_CHECKING, Any, BinaryIO, get_args, get_type_hints
|
|
|
|
|
|
-from fastapi import FastAPI, HTTPException, Request
|
|
|
-from fastapi import UploadFile as FastAPIUploadFile
|
|
|
-from fastapi.middleware import cors
|
|
|
-from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
-from fastapi.staticfiles import StaticFiles
|
|
|
+from fastapi import FastAPI
|
|
|
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
|
|
|
-from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
|
|
+from socketio import ASGIApp as EngineIOApp
|
|
|
+from socketio import AsyncNamespace, AsyncServer
|
|
|
+from starlette.applications import Starlette
|
|
|
from starlette.datastructures import Headers
|
|
|
from starlette.datastructures import UploadFile as StarletteUploadFile
|
|
|
+from starlette.exceptions import HTTPException
|
|
|
+from starlette.middleware import cors
|
|
|
+from starlette.requests import Request
|
|
|
+from starlette.responses import JSONResponse, Response, StreamingResponse
|
|
|
+from starlette.staticfiles import StaticFiles
|
|
|
+from typing_extensions import deprecated
|
|
|
|
|
|
from reflex import constants
|
|
|
from reflex.admin import AdminDash
|
|
@@ -102,6 +106,7 @@ from reflex.utils import (
|
|
|
)
|
|
|
from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env
|
|
|
from reflex.utils.imports import ImportVar
|
|
|
+from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from reflex.vars import Var
|
|
@@ -389,7 +394,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
_stateful_pages: dict[str, None] = dataclasses.field(default_factory=dict)
|
|
|
|
|
|
# The backend API object.
|
|
|
- _api: FastAPI | None = None
|
|
|
+ _api: Starlette | None = None
|
|
|
|
|
|
# The state class to use for the app.
|
|
|
_state: type[BaseState] | None = None
|
|
@@ -424,14 +429,34 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
# Put the toast provider in the app wrap.
|
|
|
toaster: Component | None = dataclasses.field(default_factory=toast.provider)
|
|
|
|
|
|
+ # Transform the ASGI app before running it.
|
|
|
+ api_transformer: (
|
|
|
+ Sequence[Callable[[ASGIApp], ASGIApp] | Starlette]
|
|
|
+ | Callable[[ASGIApp], ASGIApp]
|
|
|
+ | Starlette
|
|
|
+ | None
|
|
|
+ ) = None
|
|
|
+
|
|
|
+ # FastAPI app for compatibility with FastAPI.
|
|
|
+ _cached_fastapi_app: FastAPI | None = None
|
|
|
+
|
|
|
@property
|
|
|
- def api(self) -> FastAPI | None:
|
|
|
+ @deprecated("Use `api_transformer=your_fastapi_app` instead.")
|
|
|
+ def api(self) -> FastAPI:
|
|
|
"""Get the backend api.
|
|
|
|
|
|
Returns:
|
|
|
The backend api.
|
|
|
"""
|
|
|
- return self._api
|
|
|
+ if self._cached_fastapi_app is None:
|
|
|
+ self._cached_fastapi_app = FastAPI()
|
|
|
+ console.deprecate(
|
|
|
+ feature_name="App.api",
|
|
|
+ reason="Set `api_transformer=your_fastapi_app` instead.",
|
|
|
+ deprecation_version="0.7.9",
|
|
|
+ removal_version="0.8.0",
|
|
|
+ )
|
|
|
+ return self._cached_fastapi_app
|
|
|
|
|
|
@property
|
|
|
def event_namespace(self) -> EventNamespace | None:
|
|
@@ -463,7 +488,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
set_breakpoints(self.style.pop("breakpoints"))
|
|
|
|
|
|
# Set up the API.
|
|
|
- self._api = FastAPI(lifespan=self._run_lifespan_tasks)
|
|
|
+ self._api = Starlette(lifespan=self._run_lifespan_tasks)
|
|
|
self._add_cors()
|
|
|
self._add_default_endpoints()
|
|
|
|
|
@@ -529,7 +554,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
)
|
|
|
|
|
|
# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
|
|
|
- socket_app = ASGIApp(self.sio, socketio_path="")
|
|
|
+ socket_app = EngineIOApp(self.sio, socketio_path="")
|
|
|
namespace = config.get_event_namespace()
|
|
|
|
|
|
# Create the event namespace and attach the main app. Not related to any paths.
|
|
@@ -538,18 +563,16 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
# Register the event namespace with the socket.
|
|
|
self.sio.register_namespace(self.event_namespace)
|
|
|
# Mount the socket app with the API.
|
|
|
- if self.api:
|
|
|
+ if self._api:
|
|
|
|
|
|
class HeaderMiddleware:
|
|
|
def __init__(self, app: ASGIApp):
|
|
|
self.app = app
|
|
|
|
|
|
- async def __call__(
|
|
|
- self, scope: MutableMapping[str, Any], receive: Any, send: Callable
|
|
|
- ):
|
|
|
+ async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
|
|
original_send = send
|
|
|
|
|
|
- async def modified_send(message: dict):
|
|
|
+ async def modified_send(message: Message):
|
|
|
if message["type"] == "websocket.accept":
|
|
|
if scope.get("subprotocols"):
|
|
|
# The following *does* say "subprotocol" instead of "subprotocols", intentionally.
|
|
@@ -568,7 +591,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
return await self.app(scope, receive, modified_send)
|
|
|
|
|
|
socket_app_with_headers = HeaderMiddleware(socket_app)
|
|
|
- self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
|
|
|
+ self._api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers)
|
|
|
|
|
|
# Check the exception handlers
|
|
|
self._validate_exception_handlers()
|
|
@@ -581,7 +604,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
"""
|
|
|
return f"<App state={self._state.__name__ if self._state else None}>"
|
|
|
|
|
|
- def __call__(self) -> FastAPI:
|
|
|
+ def __call__(self) -> ASGIApp:
|
|
|
"""Run the backend api instance.
|
|
|
|
|
|
Raises:
|
|
@@ -590,8 +613,18 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
Returns:
|
|
|
The backend api.
|
|
|
"""
|
|
|
- if not self.api:
|
|
|
- raise ValueError("The app has not been initialized.")
|
|
|
+ if self._cached_fastapi_app is not None:
|
|
|
+ asgi_app = self._cached_fastapi_app
|
|
|
+
|
|
|
+ if not asgi_app or not self._api:
|
|
|
+ raise ValueError("The app has not been initialized.")
|
|
|
+
|
|
|
+ asgi_app.mount("", self._api)
|
|
|
+ else:
|
|
|
+ asgi_app = self._api
|
|
|
+
|
|
|
+ if not asgi_app:
|
|
|
+ raise ValueError("The app has not been initialized.")
|
|
|
|
|
|
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
|
|
|
# before compiling the app in a thread to avoid event loop error (REF-2172).
|
|
@@ -608,30 +641,58 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
if is_prod_mode():
|
|
|
compile_future.result()
|
|
|
|
|
|
- return self.api
|
|
|
+ if self.api_transformer is not None:
|
|
|
+ api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = (
|
|
|
+ [self.api_transformer]
|
|
|
+ if not isinstance(self.api_transformer, Sequence)
|
|
|
+ else self.api_transformer
|
|
|
+ )
|
|
|
+
|
|
|
+ for api_transformer in api_transformers:
|
|
|
+ if isinstance(api_transformer, Starlette):
|
|
|
+ # Mount the api to the fastapi app.
|
|
|
+ api_transformer.mount("", asgi_app)
|
|
|
+ asgi_app = api_transformer
|
|
|
+ else:
|
|
|
+ # Transform the asgi app.
|
|
|
+ asgi_app = api_transformer(asgi_app)
|
|
|
+
|
|
|
+ return asgi_app
|
|
|
|
|
|
def _add_default_endpoints(self):
|
|
|
"""Add default api endpoints (ping)."""
|
|
|
# To test the server.
|
|
|
- if not self.api:
|
|
|
+ if not self._api:
|
|
|
return
|
|
|
|
|
|
- self.api.get(str(constants.Endpoint.PING))(ping)
|
|
|
- self.api.get(str(constants.Endpoint.HEALTH))(health)
|
|
|
+ self._api.add_route(
|
|
|
+ str(constants.Endpoint.PING),
|
|
|
+ ping,
|
|
|
+ methods=["GET"],
|
|
|
+ )
|
|
|
+ self._api.add_route(
|
|
|
+ str(constants.Endpoint.HEALTH),
|
|
|
+ health,
|
|
|
+ methods=["GET"],
|
|
|
+ )
|
|
|
|
|
|
def _add_optional_endpoints(self):
|
|
|
"""Add optional api endpoints (_upload)."""
|
|
|
- if not self.api:
|
|
|
+ if not self._api:
|
|
|
return
|
|
|
upload_is_used_marker = (
|
|
|
prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED
|
|
|
)
|
|
|
if Upload.is_used or upload_is_used_marker.exists():
|
|
|
# To upload files.
|
|
|
- self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
|
|
|
+ self._api.add_route(
|
|
|
+ str(constants.Endpoint.UPLOAD),
|
|
|
+ upload(self),
|
|
|
+ methods=["POST"],
|
|
|
+ )
|
|
|
|
|
|
# To access uploaded files.
|
|
|
- self.api.mount(
|
|
|
+ self._api.mount(
|
|
|
str(constants.Endpoint.UPLOAD),
|
|
|
StaticFiles(directory=get_upload_dir()),
|
|
|
name="uploaded_files",
|
|
@@ -640,17 +701,19 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
|
|
|
upload_is_used_marker.touch()
|
|
|
if codespaces.is_running_in_codespaces():
|
|
|
- self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
|
|
|
- codespaces.auth_codespace
|
|
|
+ self._api.add_route(
|
|
|
+ str(constants.Endpoint.AUTH_CODESPACE),
|
|
|
+ codespaces.auth_codespace,
|
|
|
+ methods=["GET"],
|
|
|
)
|
|
|
if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
|
|
|
self.add_all_routes_endpoint()
|
|
|
|
|
|
def _add_cors(self):
|
|
|
"""Add CORS middleware to the app."""
|
|
|
- if not self.api:
|
|
|
+ if not self._api:
|
|
|
return
|
|
|
- self.api.add_middleware(
|
|
|
+ self._api.add_middleware(
|
|
|
cors.CORSMiddleware,
|
|
|
allow_credentials=True,
|
|
|
allow_methods=["*"],
|
|
@@ -915,7 +978,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
return
|
|
|
|
|
|
# Get the admin dash.
|
|
|
- if not self.api:
|
|
|
+ if not self._api:
|
|
|
return
|
|
|
|
|
|
admin_dash = self.admin_dash
|
|
@@ -936,7 +999,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
view = admin_dash.view_overrides.get(model, ModelView)
|
|
|
admin.add_view(view(model))
|
|
|
|
|
|
- admin.mount_to(self.api)
|
|
|
+ admin.mount_to(self._api)
|
|
|
|
|
|
def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]):
|
|
|
"""Gets the frontend packages to be installed and filters out the unnecessary ones.
|
|
@@ -1427,12 +1490,15 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
|
|
|
def add_all_routes_endpoint(self):
|
|
|
"""Add an endpoint to the app that returns all the routes."""
|
|
|
- if not self.api:
|
|
|
+ if not self._api:
|
|
|
return
|
|
|
|
|
|
- @self.api.get(str(constants.Endpoint.ALL_ROUTES))
|
|
|
- async def all_routes():
|
|
|
- return list(self._unevaluated_pages.keys())
|
|
|
+ async def all_routes(_request: Request) -> Response:
|
|
|
+ return JSONResponse(list(self._unevaluated_pages.keys()))
|
|
|
+
|
|
|
+ self._api.add_route(
|
|
|
+ str(constants.Endpoint.ALL_ROUTES), all_routes, methods=["GET"]
|
|
|
+ )
|
|
|
|
|
|
@contextlib.asynccontextmanager
|
|
|
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
@@ -1687,18 +1753,24 @@ async def process(
|
|
|
raise
|
|
|
|
|
|
|
|
|
-async def ping() -> str:
|
|
|
+async def ping(_request: Request) -> Response:
|
|
|
"""Test API endpoint.
|
|
|
|
|
|
+ Args:
|
|
|
+ _request: The Starlette request object.
|
|
|
+
|
|
|
Returns:
|
|
|
The response.
|
|
|
"""
|
|
|
- return "pong"
|
|
|
+ return JSONResponse("pong")
|
|
|
|
|
|
|
|
|
-async def health() -> JSONResponse:
|
|
|
+async def health(_request: Request) -> JSONResponse:
|
|
|
"""Health check endpoint to assess the status of the database and Redis services.
|
|
|
|
|
|
+ Args:
|
|
|
+ _request: The Starlette request object.
|
|
|
+
|
|
|
Returns:
|
|
|
JSONResponse: A JSON object with the health status:
|
|
|
- "status" (bool): Overall health, True if all checks pass.
|
|
@@ -1740,12 +1812,11 @@ def upload(app: App):
|
|
|
The upload function.
|
|
|
"""
|
|
|
|
|
|
- async def upload_file(request: Request, files: list[FastAPIUploadFile]):
|
|
|
+ async def upload_file(request: Request):
|
|
|
"""Upload a file.
|
|
|
|
|
|
Args:
|
|
|
- request: The FastAPI request object.
|
|
|
- files: The file(s) to upload.
|
|
|
+ request: The Starlette request object.
|
|
|
|
|
|
Returns:
|
|
|
StreamingResponse yielding newline-delimited JSON of StateUpdate
|
|
@@ -1758,6 +1829,12 @@ def upload(app: App):
|
|
|
"""
|
|
|
from reflex.utils.exceptions import UploadTypeError, UploadValueError
|
|
|
|
|
|
+ # Get the files from the request.
|
|
|
+ files = await request.form()
|
|
|
+ files = files.getlist("files")
|
|
|
+ if not files:
|
|
|
+ raise UploadValueError("No files were uploaded.")
|
|
|
+
|
|
|
token = request.headers.get("reflex-client-token")
|
|
|
handler = request.headers.get("reflex-event-handler")
|
|
|
|
|
@@ -1810,6 +1887,10 @@ def upload(app: App):
|
|
|
# event is handled.
|
|
|
file_copies = []
|
|
|
for file in files:
|
|
|
+ if not isinstance(file, StarletteUploadFile):
|
|
|
+ raise UploadValueError(
|
|
|
+ "Uploaded file is not an UploadFile." + str(file)
|
|
|
+ )
|
|
|
content_copy = io.BytesIO()
|
|
|
content_copy.write(await file.read())
|
|
|
content_copy.seek(0)
|