nicegui.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import asyncio
  2. import time
  3. import urllib.parse
  4. from pathlib import Path
  5. from typing import Dict, Optional
  6. from fastapi import HTTPException, Request
  7. from fastapi.middleware.gzip import GZipMiddleware
  8. from fastapi.responses import FileResponse, Response
  9. from fastapi.staticfiles import StaticFiles
  10. from fastapi_socketio import SocketManager
  11. from nicegui import json
  12. from nicegui.json import NiceGUIJSONResponse
  13. from . import (__version__, background_tasks, binding, favicon, globals, outbox, # pylint: disable=redefined-builtin
  14. welcome)
  15. from .app import App
  16. from .client import Client
  17. from .dependencies import js_components, libraries
  18. from .element import Element
  19. from .error import error_content
  20. from .helpers import is_file, safe_invoke
  21. from .page import page
# Create the NiceGUI application and publish it on ``globals`` so other modules
# can reach it; route responses default to NiceGUIJSONResponse.
globals.app = app = App(default_response_class=NiceGUIJSONResponse)
# NOTE we use custom json module which wraps orjson
# The Socket.IO endpoint is mounted at /_nicegui_ws/ and (de)serializes with that json module.
socket_manager = SocketManager(app=app, mount_location='/_nicegui_ws/', json=json)
# Expose the underlying Socket.IO server, which SocketManager only keeps as a private attribute.
globals.sio = sio = socket_manager._sio  # pylint: disable=protected-access

# Enable gzip compression of responses.
app.add_middleware(GZipMiddleware)
# Serve the package's bundled ``static`` directory; symlinks inside it are followed.
static_files = StaticFiles(
    directory=(Path(__file__).parent / 'static').resolve(),
    follow_symlink=True,
)
# The package version is part of the URL, so every release gets distinct static URLs
# (presumably to avoid stale browser caches after upgrades — TODO confirm).
app.mount(f'/_nicegui/{__version__}/static', static_files, name='static')

# Shared client backing the root page '/'. It is entered here without a matching
# __exit__, so it remains the active client context after import; presumably so that
# elements created at module level end up on the index page — verify.
globals.index_client = Client(page('/'), shared=True).__enter__()  # pylint: disable=unnecessary-dunder-call
  33. @app.get('/')
  34. def index(request: Request) -> Response:
  35. return globals.index_client.build_response(request)
  36. @app.get(f'/_nicegui/{__version__}' + '/libraries/{key:path}')
  37. def get_library(key: str) -> FileResponse:
  38. is_map = key.endswith('.map')
  39. dict_key = key[:-4] if is_map else key
  40. if dict_key in libraries:
  41. path = libraries[dict_key].path
  42. if is_map:
  43. path = path.with_name(path.name + '.map')
  44. if path.exists():
  45. headers = {'Cache-Control': 'public, max-age=3600'}
  46. return FileResponse(path, media_type='text/javascript', headers=headers)
  47. raise HTTPException(status_code=404, detail=f'library "{key}" not found')
  48. @app.get(f'/_nicegui/{__version__}' + '/components/{key:path}')
  49. def get_component(key: str) -> FileResponse:
  50. if key in js_components and js_components[key].path.exists():
  51. headers = {'Cache-Control': 'public, max-age=3600'}
  52. return FileResponse(js_components[key].path, media_type='text/javascript', headers=headers)
  53. raise HTTPException(status_code=404, detail=f'component "{key}" not found')
  54. @app.on_event('startup')
  55. def handle_startup(with_welcome_message: bool = True) -> None:
  56. if not globals.ui_run_has_been_called:
  57. raise RuntimeError('\n\n'
  58. 'You must call ui.run() to start the server.\n'
  59. 'If ui.run() is behind a main guard\n'
  60. ' if __name__ == "__main__":\n'
  61. 'remove the guard or replace it with\n'
  62. ' if __name__ in {"__main__", "__mp_main__"}:\n'
  63. 'to allow for multiprocessing.')
  64. if globals.favicon:
  65. if is_file(globals.favicon):
  66. globals.app.add_route('/favicon.ico', lambda _: FileResponse(globals.favicon))
  67. else:
  68. globals.app.add_route('/favicon.ico', lambda _: favicon.get_favicon_response())
  69. else:
  70. globals.app.add_route('/favicon.ico', lambda _: FileResponse(Path(__file__).parent / 'static' / 'favicon.ico'))
  71. globals.state = globals.State.STARTING
  72. globals.loop = asyncio.get_running_loop()
  73. with globals.index_client:
  74. for t in globals.startup_handlers:
  75. safe_invoke(t)
  76. background_tasks.create(binding.loop())
  77. background_tasks.create(outbox.loop())
  78. background_tasks.create(prune_clients())
  79. background_tasks.create(prune_slot_stacks())
  80. globals.state = globals.State.STARTED
  81. if with_welcome_message:
  82. welcome.print_message()
  83. if globals.air:
  84. background_tasks.create(globals.air.connect())
  85. @app.on_event('shutdown')
  86. async def handle_shutdown() -> None:
  87. if app.native.main_window:
  88. app.native.main_window.signal_server_shutdown()
  89. globals.state = globals.State.STOPPING
  90. with globals.index_client:
  91. for t in globals.shutdown_handlers:
  92. safe_invoke(t)
  93. globals.state = globals.State.STOPPED
  94. if globals.air:
  95. await globals.air.disconnect()
  96. @app.exception_handler(404)
  97. async def exception_handler_404(request: Request, exception: Exception) -> Response:
  98. globals.log.warning(f'{request.url} not found')
  99. with Client(page('')) as client:
  100. error_content(404, exception)
  101. return client.build_response(request, 404)
  102. @app.exception_handler(Exception)
  103. async def exception_handler_500(request: Request, exception: Exception) -> Response:
  104. globals.log.exception(exception)
  105. with Client(page('')) as client:
  106. error_content(500, exception)
  107. return client.build_response(request, 500)
  108. @sio.on('handshake')
  109. def on_handshake(sid: str) -> bool:
  110. client = get_client(sid)
  111. if not client:
  112. return False
  113. client.environ = sio.get_environ(sid)
  114. sio.enter_room(sid, client.id)
  115. handle_handshake(client)
  116. return True
  117. def handle_handshake(client: Client) -> None:
  118. for t in client.connect_handlers:
  119. safe_invoke(t, client)
  120. for t in globals.connect_handlers:
  121. safe_invoke(t, client)
  122. @sio.on('disconnect')
  123. def on_disconnect(sid: str) -> None:
  124. client = get_client(sid)
  125. if not client:
  126. return
  127. handle_disconnect(client)
  128. def handle_disconnect(client: Client) -> None:
  129. if not client.shared:
  130. delete_client(client.id)
  131. for t in client.disconnect_handlers:
  132. safe_invoke(t, client)
  133. for t in globals.disconnect_handlers:
  134. safe_invoke(t, client)
  135. @sio.on('event')
  136. def on_event(sid: str, msg: Dict) -> None:
  137. client = get_client(sid)
  138. if not client or not client.has_socket_connection:
  139. return
  140. handle_event(client, msg)
  141. def handle_event(client: Client, msg: Dict) -> None:
  142. with client:
  143. sender = client.elements.get(msg['id'])
  144. if sender:
  145. msg['args'] = [None if arg is None else json.loads(arg) for arg in msg.get('args', [])]
  146. if len(msg['args']) == 1:
  147. msg['args'] = msg['args'][0]
  148. sender._handle_event(msg) # pylint: disable=protected-access
  149. @sio.on('javascript_response')
  150. def on_javascript_response(sid: str, msg: Dict) -> None:
  151. client = get_client(sid)
  152. if not client:
  153. return
  154. handle_javascript_response(client, msg)
  155. def handle_javascript_response(client: Client, msg: Dict) -> None:
  156. client.waiting_javascript_commands[msg['request_id']] = msg['result']
  157. def get_client(sid: str) -> Optional[Client]:
  158. query_bytes: bytearray = sio.get_environ(sid)['asgi.scope']['query_string']
  159. query = urllib.parse.parse_qs(query_bytes.decode())
  160. client_id = query['client_id'][0]
  161. return globals.clients.get(client_id)
  162. async def prune_clients() -> None:
  163. while True:
  164. stale_clients = [
  165. id
  166. for id, client in globals.clients.items()
  167. if not client.shared and not client.has_socket_connection and client.created < time.time() - 60.0
  168. ]
  169. for client_id in stale_clients:
  170. delete_client(client_id)
  171. await asyncio.sleep(10)
  172. async def prune_slot_stacks() -> None:
  173. while True:
  174. running = [
  175. id(task)
  176. for task in asyncio.tasks.all_tasks()
  177. if not task.done() and not task.cancelled()
  178. ]
  179. stale = [
  180. id_
  181. for id_ in globals.slot_stacks
  182. if id_ not in running
  183. ]
  184. for id_ in stale:
  185. del globals.slot_stacks[id_]
  186. await asyncio.sleep(10)
  187. def delete_client(client_id: str) -> None:
  188. binding.remove(list(globals.clients[client_id].elements.values()), Element)
  189. for element in globals.clients[client_id].elements.values():
  190. element.delete()
  191. del globals.clients[client_id]