|
@@ -3,6 +3,7 @@
|
|
|
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
|
|
|
|
|
|
from fastapi import FastAPI
|
|
|
+from fastapi.middleware import cors
|
|
|
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
|
|
|
|
|
from pynecone import constants, utils
|
|
@@ -74,6 +75,8 @@ class App(Base):
|
|
|
|
|
|
# Set up the API.
|
|
|
self.api = FastAPI()
|
|
|
+ self.add_cors(config.cors_allowed_origins)
|
|
|
+ self.add_default_endpoints()
|
|
|
|
|
|
# Set up CORS options.
|
|
|
cors_allowed_origins = config.cors_allowed_origins
|
|
@@ -116,6 +119,26 @@ class App(Base):
|
|
|
"""
|
|
|
return self.api
|
|
|
|
|
|
+ def add_default_endpoints(self):
|
|
|
+ """Add the default endpoints."""
|
|
|
+ # To test the server.
|
|
|
+ self.api.get(str(constants.Endpoint.PING))(ping)
|
|
|
+
|
|
|
+ def add_cors(self, allowed_origins: Optional[List[str]] = None):
|
|
|
+ """Add CORS middleware to the app.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ allowed_origins: A list of allowed origins.
|
|
|
+ """
|
|
|
+ allowed_origins = allowed_origins or ["*"]
|
|
|
+ self.api.add_middleware(
|
|
|
+ cors.CORSMiddleware,
|
|
|
+ allow_origins=allowed_origins,
|
|
|
+ allow_credentials=True,
|
|
|
+ allow_methods=["*"],
|
|
|
+ allow_headers=["*"],
|
|
|
+ )
|
|
|
+
|
|
|
def preprocess(self, state: State, event: Event) -> Optional[Delta]:
|
|
|
"""Preprocess the event.
|
|
|
|
|
@@ -392,6 +415,15 @@ async def process(
|
|
|
return update
|
|
|
|
|
|
|
|
|
+async def ping() -> str:
|
|
|
+ """Test API endpoint.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The response.
|
|
|
+ """
|
|
|
+ return "pong"
|
|
|
+
|
|
|
+
|
|
|
class EventNamespace(AsyncNamespace):
|
|
|
"""The event namespace."""
|
|
|
|