Browse Source

using fly-replay for routing websockets
as demoed in https://github.com/zauberzeug/fly_fastapi_socketio

Rodja Trappe 1 year ago
parent
commit
79af0e82a5
4 changed files with 43 additions and 7 deletions
  1. 38 6
      main.py
  2. 2 0
      nicegui/client.py
  3. 2 0
      nicegui/globals.py
  4. 1 1
      nicegui/templates/index.html

+ 38 - 6
main.py

@@ -4,10 +4,13 @@ import inspect
 import os
 from pathlib import Path
 from typing import Awaitable, Callable, Optional
+from urllib.parse import parse_qs
 
 from fastapi import Request
 from fastapi.responses import FileResponse, RedirectResponse, Response
+from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
+from starlette.types import ASGIApp, Receive, Scope, Send
 
 import prometheus
 from nicegui import Client, app
@@ -52,12 +55,41 @@ async def redirect_reference_to_documentation(request: Request,
         return RedirectResponse('/documentation')
     return await call_next(request)
 
-# NOTE in our global fly.io deployment we need to make sure that the websocket connects back to the same instance
-fly_instance_id = os.environ.get('FLY_ALLOC_ID', '').split('-')[0]
-if fly_instance_id:
-    nicegui_globals.socket_io_js_extra_headers['fly-force-instance-id'] = fly_instance_id
-    # NOTE polling is required for fly.io to use the force-instance header
-    nicegui_globals.socket_io_js_transports = ['polling']
+# NOTE in our global fly.io deployment we need to make sure that we connect back to the same instance.
+fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
+nicegui_globals.socket_io_js_extra_headers['fly-force-instance-id'] = fly_instance_id  # for http long polling
+nicegui_globals.socket_io_js_query_params['fly_instance_id'] = fly_instance_id  # for websocket (FlyReplayMiddleware)
+
+
+class FlyReplayMiddleware(BaseHTTPMiddleware):
+    """
+    If the wrong instance was picked by the fly.io load balancer we use the fly-replay header
+    to repeat the request again on the right instance.
+
+    This only works if the correct instance is provided as a query_string parameter.
+    """
+
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        query_string = scope.get('query_string', b'').decode()
+        query_params = parse_qs(query_string)
+        target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]
+
+        async def send_wrapper(message):
+            if target_instance != fly_instance_id:
+                if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']:
+                    # fly.io only seems to look at the fly-replay header if websocket is accepted
+                    message = {'type': 'websocket.accept'}
+                if 'headers' not in message:
+                    message['headers'] = []
+                message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
+            await send(message)
+        await self.app(scope, receive, send_wrapper)
+
+
+app.add_middleware(FlyReplayMiddleware)
 
 
 def add_head_html() -> None:

+ 2 - 0
nicegui/client.py

@@ -71,6 +71,7 @@ class Client:
     def build_response(self, request: Request, status_code: int = 200) -> Response:
         prefix = request.headers.get('X-Forwarded-Prefix', request.scope.get('root_path', ''))
         elements = json.dumps({id: element._to_dict() for id, element in self.elements.items()})
+        socket_io_js_query_params = globals.socket_io_js_query_params | {'client_id': self.id}
         vue_html, vue_styles, vue_scripts, imports, js_imports = generate_resources(prefix, self.elements.values())
         return templates.TemplateResponse('index.html', {
             'request': request,
@@ -89,6 +90,7 @@ class Client:
             'language': self.page.resolve_language(),
             'prefix': prefix,
             'tailwind': globals.tailwind,
+            'socket_io_js_query_params': socket_io_js_query_params,
             'socket_io_js_extra_headers': globals.socket_io_js_extra_headers,
             'socket_io_js_transports': globals.socket_io_js_transports,
         }, status_code, {'Cache-Control': 'no-store', 'X-NiceGUI-Content': 'page'})

+ 2 - 0
nicegui/globals.py

@@ -44,7 +44,9 @@ language: Language
 binding_refresh_interval: float
 tailwind: bool
 air: Optional['Air'] = None
+socket_io_js_query_params: Dict = {}
 socket_io_js_extra_headers: Dict = {}
+# NOTE we favour websocket over polling
 socket_io_js_transports: List[Literal['websocket', 'polling']] = ['websocket', 'polling']
 
 _socket_id: Optional[str] = None

+ 1 - 1
nicegui/templates/index.html

@@ -213,7 +213,7 @@
         },
         mounted() {
           window.app = this;
-          const query = { client_id: "{{ client_id }}" };
+          const query = {{ socket_io_js_query_params | safe }};
           const url = window.location.protocol === 'https:' ? 'wss://' : 'ws://' + window.location.host;
           const extraHeaders = {{ socket_io_js_extra_headers | safe }};
           const transports = {{ socket_io_js_transports | safe }};