# fly.py
import asyncio
import os
from urllib.parse import parse_qs

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send

from nicegui import app
  6. def setup() -> bool:
  7. """Setup fly.io specific settings.
  8. Returns True if running on fly.io, False otherwise.
  9. """
  10. if 'FLY_ALLOC_ID' not in os.environ:
  11. return False
  12. class FlyReplayMiddleware(BaseHTTPMiddleware):
  13. """Replay to correct fly.io instance.
  14. If the wrong instance was picked by the fly.io load balancer, we use the fly-replay header
  15. to repeat the request again on the right instance.
  16. This only works if the correct instance is provided as a query_string parameter.
  17. """
  18. def __init__(self, asgi_app: ASGIApp) -> None:
  19. super().__init__(asgi_app)
  20. self.app = asgi_app
  21. self.app_name = os.environ.get('FLY_APP_NAME')
  22. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  23. query_string = scope.get('query_string', b'').decode()
  24. query_params = parse_qs(query_string)
  25. target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]
  26. async def send_wrapper(message):
  27. if target_instance != fly_instance_id and self.is_online(target_instance):
  28. if message['type'] == 'websocket.close':
  29. # fly.io only seems to look at the fly-replay header if websocket is accepted
  30. message = {'type': 'websocket.accept'}
  31. if 'headers' not in message:
  32. message['headers'] = []
  33. message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
  34. await send(message)
  35. await self.app(scope, receive, send_wrapper)
  36. def is_online(self, fly_instance_id: str) -> bool:
  37. hostname = f'{fly_instance_id}.vm.{self.app_name}.internal'
  38. try:
  39. dns.resolver.resolve(hostname, 'AAAA')
  40. return True
  41. except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.resolver.NoNameservers, dns.resolver.Timeout):
  42. return False
  43. # NOTE In our global fly.io deployment we need to make sure that we connect back to the same instance.
  44. fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
  45. app.config.socket_io_js_extra_headers['fly-force-instance-id'] = fly_instance_id # for HTTP long polling
  46. app.config.socket_io_js_query_params['fly_instance_id'] = fly_instance_id # for websocket (FlyReplayMiddleware)
  47. import dns.resolver # NOTE only import on fly where we have it installed to look up if instance is still available
  48. app.add_middleware(FlyReplayMiddleware)
  49. return True