air.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from __future__ import annotations
  2. import asyncio
  3. import gzip
  4. import json
  5. import logging
  6. import re
  7. from dataclasses import dataclass
  8. from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Optional
  9. from uuid import uuid4
  10. import socketio
  11. import socketio.exceptions
  12. from . import background_tasks, core
  13. from .client import Client
  14. from .dataclasses import KWONLY_SLOTS
  15. from .elements.timer import Timer as timer
  16. from .logging import log
  17. if TYPE_CHECKING:
  18. import httpx
# Base URL of the NiceGUI On Air relay server. Air.connect() appends "?device_token=<token>"
# to it and talks Socket.IO under the "/on_air/socket.io" path.
RELAY_HOST = 'https://on-air.nicegui.io/'
@dataclass(**KWONLY_SLOTS)
class Stream:
    """A streamed HTTP response whose body is handed out chunk by chunk via the relay's "read-stream" events."""
    data: AsyncIterator[bytes]  # chunk iterator, created from ``response.aiter_bytes()`` when the stream is opened
    response: httpx.Response  # the underlying streamed response; must be closed with ``aclose()`` when done
  24. class Air:
  25. def __init__(self, token: str) -> None:
  26. import httpx # pylint: disable=import-outside-toplevel
  27. self.log = logging.getLogger('nicegui.air')
  28. self.token = token
  29. self.relay = socketio.AsyncClient()
  30. self.client = httpx.AsyncClient(transport=httpx.ASGITransport(app=core.app))
  31. self.streaming_client = httpx.AsyncClient()
  32. self.connecting = False
  33. self.streams: Dict[str, Stream] = {}
  34. self.remote_url: Optional[str] = None
  35. timer(5, self.connect) # ensure we stay connected
  36. @self.relay.on('http')
  37. async def _handle_http(data: Dict[str, Any]) -> Dict[str, Any]:
  38. headers: Dict[str, Any] = data['headers']
  39. headers.update({'Accept-Encoding': 'identity', 'X-Forwarded-Prefix': data['prefix']})
  40. url = 'http://test' + data['path']
  41. request = self.client.build_request(
  42. data['method'],
  43. url,
  44. params=data['params'],
  45. headers=headers,
  46. content=data['body'],
  47. )
  48. response = await self.client.send(request)
  49. self.client.cookies.clear()
  50. instance_id = data['instance-id']
  51. content = response.content.replace(
  52. b'const extraHeaders = {};',
  53. (f'const extraHeaders = {{ "fly-force-instance-id" : "{instance_id}" }};').encode(),
  54. )
  55. match = re.search(b'const query = ({.*?})', content)
  56. if match:
  57. new_js_object = match.group(1).decode().rstrip('}') + ", 'fly_instance_id' : '" + instance_id + "'}"
  58. content = content.replace(match.group(0), f'const query = {new_js_object}'.encode())
  59. compressed = gzip.compress(content)
  60. response.headers.update({'content-encoding': 'gzip', 'content-length': str(len(compressed))})
  61. return {
  62. 'status_code': response.status_code,
  63. 'headers': response.headers.multi_items(),
  64. 'content': compressed,
  65. }
  66. @self.relay.on('range-request')
  67. async def _handle_range_request(data: Dict[str, Any]) -> Dict[str, Any]:
  68. headers: Dict[str, Any] = data['headers']
  69. url = next(iter(u for u in core.app.urls if self.remote_url != u)) + data['path']
  70. data['params']['nicegui_chunk_size'] = 1024
  71. request = self.client.build_request(
  72. data['method'],
  73. url,
  74. params=data['params'],
  75. headers=headers,
  76. )
  77. response = await self.streaming_client.send(request, stream=True)
  78. stream_id = str(uuid4())
  79. self.streams[stream_id] = Stream(data=response.aiter_bytes(), response=response)
  80. return {
  81. 'status_code': response.status_code,
  82. 'headers': response.headers.multi_items(),
  83. 'stream_id': stream_id,
  84. }
  85. @self.relay.on('read-stream')
  86. async def _handle_read_stream(stream_id: str) -> Optional[bytes]:
  87. try:
  88. return await self.streams[stream_id].data.__anext__()
  89. except StopAsyncIteration:
  90. await _handle_close_stream(stream_id)
  91. return None
  92. except Exception:
  93. await _handle_close_stream(stream_id)
  94. raise
  95. @self.relay.on('close-stream')
  96. async def _handle_close_stream(stream_id: str) -> None:
  97. await self.streams[stream_id].response.aclose()
  98. del self.streams[stream_id]
  99. @self.relay.on('ready')
  100. def _handle_ready(data: Dict[str, Any]) -> None:
  101. core.app.urls.add(data['device_url'])
  102. self.remote_url = data['device_url']
  103. if core.app.config.show_welcome_message:
  104. print(f'NiceGUI is on air at {data["device_url"]}', flush=True)
  105. @self.relay.on('error')
  106. def _handleerror(data: Dict[str, Any]) -> None:
  107. print('Error:', data['message'], flush=True)
  108. @self.relay.on('handshake')
  109. def _handle_handshake(data: Dict[str, Any]) -> bool:
  110. client_id = data['client_id']
  111. if client_id not in Client.instances:
  112. return False
  113. client = Client.instances[client_id]
  114. client.environ = data['environ']
  115. if data.get('old_tab_id'):
  116. core.app.storage.copy_tab(data['old_tab_id'], data['tab_id'])
  117. client.tab_id = data['tab_id']
  118. client.on_air = True
  119. client.handle_handshake(data.get('next_message_id'))
  120. return True
  121. @self.relay.on('client_disconnect')
  122. def _handle_client_disconnect(data: Dict[str, Any]) -> None:
  123. self.log.debug('client disconnected.')
  124. client_id = data['client_id']
  125. if client_id not in Client.instances:
  126. return
  127. Client.instances[client_id].handle_disconnect()
  128. @self.relay.on('connect')
  129. async def _handle_connect() -> None:
  130. self.log.debug('connected.')
  131. @self.relay.on('disconnect')
  132. async def _handle_disconnect() -> None:
  133. self.log.debug('disconnected.')
  134. @self.relay.on('connect_error')
  135. async def _handle_connect_error(data) -> None:
  136. self.log.debug(f'Connection error: {data}')
  137. @self.relay.on('event')
  138. def _handle_event(data: Dict[str, Any]) -> None:
  139. client_id = data['client_id']
  140. if client_id not in Client.instances:
  141. return
  142. client = Client.instances[client_id]
  143. args = data['msg']['args']
  144. if args and isinstance(args[0], str) and args[0].startswith('{"socket_id":'):
  145. arg0 = json.loads(args[0])
  146. arg0['socket_id'] = client_id # HACK: translate socket_id of ui.scene's init event
  147. args[0] = json.dumps(arg0)
  148. client.handle_event(data['msg'])
  149. @self.relay.on('javascript_response')
  150. def _handle_javascript_response(data: Dict[str, Any]) -> None:
  151. client_id = data['client_id']
  152. if client_id not in Client.instances:
  153. return
  154. client = Client.instances[client_id]
  155. client.handle_javascript_response(data['msg'])
  156. @self.relay.on('ack')
  157. def _handle_ack(data: Dict[str, Any]) -> None:
  158. client_id = data['client_id']
  159. if client_id not in Client.instances:
  160. return
  161. client = Client.instances[client_id]
  162. client.outbox.prune_history(data['msg']['next_message_id'])
  163. @self.relay.on('out_of_time')
  164. async def _handle_out_of_time() -> None:
  165. print('Sorry, you have reached the time limit of this NiceGUI On Air preview.', flush=True)
  166. await self.connect()
  167. async def connect(self) -> None:
  168. """Connect to the NiceGUI On Air server."""
  169. if self.connecting:
  170. self.log.debug('Already connecting.')
  171. return
  172. if self.relay.connected:
  173. return
  174. self.log.debug('Going to connect...')
  175. self.connecting = True
  176. try:
  177. if self.relay.connected:
  178. await asyncio.wait_for(self.disconnect(), timeout=5)
  179. self.log.debug('Connecting...')
  180. await self.relay.connect(
  181. f'{RELAY_HOST}?device_token={self.token}',
  182. socketio_path='/on_air/socket.io',
  183. transports=['websocket', 'polling'], # favor websocket over polling
  184. wait_timeout=5,
  185. )
  186. assert self.relay.connected
  187. return
  188. except socketio.exceptions.ConnectionError:
  189. self.log.debug('Connection error.', stack_info=True)
  190. except ValueError: # NOTE this sometimes happens when the internal socketio client is not yet ready
  191. self.log.debug('ValueError while connecting.', stack_info=True)
  192. except Exception:
  193. log.exception('Could not connect to NiceGUI On Air server.')
  194. finally:
  195. self.connecting = False
  196. async def disconnect(self) -> None:
  197. """Disconnect from the NiceGUI On Air server."""
  198. self.log.debug('Disconnecting...')
  199. if self.relay.connected:
  200. await self.relay.disconnect()
  201. for stream in self.streams.values():
  202. await stream.response.aclose()
  203. self.streams.clear()
  204. self.log.debug('Disconnected.')
  205. async def emit(self, message_type: str, data: Dict[str, Any], room: str) -> None:
  206. """Emit a message to the NiceGUI On Air server."""
  207. if self.relay.connected:
  208. await self.relay.emit('forward', {'event': message_type, 'data': data, 'room': room})
  209. @staticmethod
  210. def is_air_target(target_id: str) -> bool:
  211. """Whether the given target ID is an On Air client or a SocketIO room."""
  212. if target_id in Client.instances:
  213. return Client.instances[target_id].on_air
  214. return target_id in core.sio.manager.rooms
  215. def connect() -> None:
  216. """Connect to the NiceGUI On Air server if there is an air instance."""
  217. if core.air:
  218. background_tasks.create(core.air.connect())
  219. def disconnect() -> None:
  220. """Disconnect from the NiceGUI On Air server if there is an air instance."""
  221. if core.air:
  222. background_tasks.create(core.air.disconnect())