|
@@ -4,10 +4,13 @@ import inspect
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
from typing import Awaitable, Callable, Optional
|
|
|
+from urllib.parse import parse_qs
|
|
|
|
|
|
from fastapi import Request
|
|
|
from fastapi.responses import FileResponse, RedirectResponse, Response
|
|
|
+from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
|
+from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
|
|
import prometheus
|
|
|
from nicegui import Client, app
|
|
@@ -52,12 +55,41 @@ async def redirect_reference_to_documentation(request: Request,
|
|
|
return RedirectResponse('/documentation')
|
|
|
return await call_next(request)
|
|
|
|
|
|
-# NOTE in our global fly.io deployment we need to make sure that the websocket connects back to the same instance
|
|
|
-fly_instance_id = os.environ.get('FLY_ALLOC_ID', '').split('-')[0]
|
|
|
-if fly_instance_id:
|
|
|
- nicegui_globals.socket_io_js_extra_headers['fly-force-instance-id'] = fly_instance_id
|
|
|
- # NOTE polling is required for fly.io to use the force-instance header
|
|
|
- nicegui_globals.socket_io_js_transports = ['polling']
|
|
|
+# NOTE in our global fly.io deployment we need to make sure that we connect back to the same instance.
|
|
|
+fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
|
|
|
+nicegui_globals.socket_io_js_extra_headers['fly-force-instance-id'] = fly_instance_id # for http long polling
|
|
|
+nicegui_globals.socket_io_js_query_params['fly_instance_id'] = fly_instance_id # for websocket (FlyReplayMiddleware)
|
|
|
+
|
|
|
+
|
|
|
+class FlyReplayMiddleware(BaseHTTPMiddleware):
|
|
|
+ """
|
|
|
+ If the wrong instance was picked by the fly.io load balancer we use the fly-replay header
|
|
|
+ to repeat the request again on the right instance.
|
|
|
+
|
|
|
+ This only works if the correct instance is provided as a query_string parameter.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, app: ASGIApp) -> None:
|
|
|
+ self.app = app
|
|
|
+
|
|
|
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
|
+ query_string = scope.get('query_string', b'').decode()
|
|
|
+ query_params = parse_qs(query_string)
|
|
|
+ target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]
|
|
|
+
|
|
|
+ async def send_wrapper(message):
|
|
|
+ if target_instance != fly_instance_id:
|
|
|
+ if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']:
|
|
|
+ # fly.io only seems to look at the fly-replay header if websocket is accepted
|
|
|
+ message = {'type': 'websocket.accept'}
|
|
|
+ if 'headers' not in message:
|
|
|
+ message['headers'] = []
|
|
|
+ message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
|
|
|
+ await send(message)
|
|
|
+ await self.app(scope, receive, send_wrapper)
|
|
|
+
|
|
|
+
|
|
|
+app.add_middleware(FlyReplayMiddleware)
|
|
|
|
|
|
|
|
|
def add_head_html() -> None:
|