fly.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
import asyncio
import os
from urllib.parse import parse_qs

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send

from nicegui import app
  6. def setup() -> bool:
  7. """Setup fly.io specific settings.
  8. Returns True if running on fly.io, False otherwise.
  9. """
  10. if 'FLY_ALLOC_ID' not in os.environ:
  11. return False
  12. class FlyReplayMiddleware(BaseHTTPMiddleware):
  13. """Replay to correct fly.io instance.
  14. If the wrong instance was picked by the fly.io load balancer, we use the fly-replay header
  15. to repeat the request again on the right instance.
  16. This only works if the correct instance is provided as a query_string parameter.
  17. """
  18. def __init__(self, app: ASGIApp) -> None:
  19. self.app = app
  20. self.app_name = os.environ.get('FLY_APP_NAME')
  21. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  22. query_string = scope.get('query_string', b'').decode()
  23. query_params = parse_qs(query_string)
  24. target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]
  25. async def send_wrapper(message):
  26. if target_instance != fly_instance_id and self.is_online(target_instance):
  27. if message['type'] == 'websocket.close':
  28. # fly.io only seems to look at the fly-replay header if websocket is accepted
  29. message = {'type': 'websocket.accept'}
  30. if 'headers' not in message:
  31. message['headers'] = []
  32. message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
  33. await send(message)
  34. await self.app(scope, receive, send_wrapper)
  35. def is_online(self, fly_instance_id: str) -> bool:
  36. hostname = f'{fly_instance_id}.vm.{self.app_name}.internal'
  37. try:
  38. dns.resolver.resolve(hostname, 'AAAA')
  39. return True
  40. except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.resolver.NoNameservers, dns.resolver.Timeout):
  41. return False
  42. # NOTE In our global fly.io deployment we need to make sure that we connect back to the same instance.
  43. fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
  44. app.config.socket_io_js_extra_headers['fly-force-instance-id'] = fly_instance_id # for HTTP long polling
  45. app.config.socket_io_js_query_params['fly_instance_id'] = fly_instance_id # for websocket (FlyReplayMiddleware)
  46. import dns.resolver # NOTE only import on fly where we have it installed to look up if instance is still available
  47. app.add_middleware(FlyReplayMiddleware)
  48. return True