fly.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import logging
  2. import os
  3. from urllib.parse import parse_qs
  4. from starlette.middleware.base import BaseHTTPMiddleware
  5. from starlette.types import ASGIApp, Receive, Scope, Send
  6. from nicegui import app
  7. def setup() -> bool:
  8. """Setup fly.io specific settings.
  9. Returns True if running on fly.io, False otherwise.
  10. """
  11. if 'FLY_ALLOC_ID' not in os.environ:
  12. return False
  13. class FlyReplayMiddleware(BaseHTTPMiddleware):
  14. """Replay to correct fly.io instance.
  15. If the wrong instance was picked by the fly.io load balancer, we use the fly-replay header
  16. to repeat the request again on the right instance.
  17. This only works if the correct instance is provided as a query_string parameter.
  18. """
  19. def __init__(self, app: ASGIApp) -> None:
  20. super().__init__(app)
  21. self.app = app
  22. self.app_name = os.environ.get('FLY_APP_NAME')
  23. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  24. query_string = scope.get('query_string', b'').decode()
  25. query_params = parse_qs(query_string)
  26. target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]
  27. async def send_wrapper(message):
  28. if target_instance != fly_instance_id and self.is_online(target_instance):
  29. if message['type'] == 'websocket.close':
  30. # fly.io only seems to look at the fly-replay header if websocket is accepted
  31. message = {'type': 'websocket.accept'}
  32. if 'headers' not in message:
  33. message['headers'] = []
  34. message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
  35. await send(message)
  36. try:
  37. await self.app(scope, receive, send_wrapper)
  38. except RuntimeError as e:
  39. if 'No response returned.' in str(e):
  40. logging.warning(f'no response returned for {scope["path"]}')
  41. else:
  42. logging.exception('could not handle request')
  43. def is_online(self, fly_instance_id: str) -> bool:
  44. hostname = f'{fly_instance_id}.vm.{self.app_name}.internal'
  45. try:
  46. dns.resolver.resolve(hostname, 'AAAA')
  47. return True
  48. except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.resolver.NoNameservers, dns.resolver.Timeout):
  49. return False
  50. # NOTE In our global fly.io deployment we need to make sure that we connect back to the same instance.
  51. fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
  52. app.config.socket_io_js_extra_headers['fly-force-instance-id'] = fly_instance_id # for HTTP long polling
  53. app.config.socket_io_js_query_params['fly_instance_id'] = fly_instance_id # for websocket (FlyReplayMiddleware)
  54. import dns.resolver # NOTE only import on fly where we have it installed to look up if instance is still available
  55. app.add_middleware(FlyReplayMiddleware)
  56. return True