fastapi.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import os
  2. import asyncio
  3. import json
  4. import logging
  5. from functools import partial
  6. import uvicorn
  7. from starlette.applications import Starlette
  8. from starlette.requests import Request
  9. from starlette.responses import HTMLResponse
  10. from starlette.routing import Route, WebSocketRoute, Mount
  11. from starlette.websockets import WebSocket
  12. from starlette.websockets import WebSocketDisconnect
  13. from .remote_access import start_remote_access_service
  14. from .tornado import open_webbrowser_on_server_started
  15. from .page import make_applications, render_page
  16. from .utils import cdn_validation, OriginChecker, deserialize_binary_event, print_listen_address
  17. from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
  18. from ..session.base import get_session_info_from_headers
  19. from ..utils import get_free_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction, strip_space
# Module-level logger; the WebSocket handler below uses it to record (at DEBUG level)
# whether a connection was closed by the session or by the client.
logger = logging.getLogger(__name__)
  21. def _webio_routes(applications, cdn, check_origin_func):
  22. """
  23. :param dict applications: dict of `name -> task function`
  24. :param bool/str cdn: Whether to load front-end static resources from CDN
  25. :param callable check_origin_func: check_origin_func(origin, host) -> bool
  26. """
  27. async def http_endpoint(request: Request):
  28. origin = request.headers.get('origin')
  29. if origin and not check_origin_func(origin=origin, host=request.headers.get('host')):
  30. return HTMLResponse(status_code=403, content="Cross origin websockets not allowed")
  31. # Backward compatible
  32. if request.query_params.get('test'):
  33. return HTMLResponse(content="")
  34. app_name = request.query_params.get('app', 'index')
  35. app = applications.get(app_name) or applications['index']
  36. no_cdn = cdn is True and request.query_params.get('_pywebio_cdn', '') == 'false'
  37. html = render_page(app, protocol='ws', cdn=False if no_cdn else cdn)
  38. return HTMLResponse(content=html)
  39. async def websocket_endpoint(websocket: WebSocket):
  40. ioloop = asyncio.get_event_loop()
  41. await websocket.accept()
  42. close_from_session_tag = False # session close causes websocket close
  43. def send_msg_to_client(session: Session):
  44. for msg in session.get_task_commands():
  45. ioloop.create_task(websocket.send_json(msg))
  46. def close_from_session():
  47. nonlocal close_from_session_tag
  48. close_from_session_tag = True
  49. ioloop.create_task(websocket.close())
  50. logger.debug("WebSocket closed from session")
  51. session_info = get_session_info_from_headers(websocket.headers)
  52. session_info['user_ip'] = websocket.client.host or ''
  53. session_info['request'] = websocket
  54. session_info['backend'] = 'starlette'
  55. session_info['protocol'] = 'websocket'
  56. app_name = websocket.query_params.get('app', 'index')
  57. application = applications.get(app_name) or applications['index']
  58. if iscoroutinefunction(application) or isgeneratorfunction(application):
  59. session = CoroutineBasedSession(application, session_info=session_info,
  60. on_task_command=send_msg_to_client,
  61. on_session_close=close_from_session)
  62. else:
  63. session = ThreadBasedSession(application, session_info=session_info,
  64. on_task_command=send_msg_to_client,
  65. on_session_close=close_from_session, loop=ioloop)
  66. while True:
  67. try:
  68. msg = await websocket.receive()
  69. if msg["type"] == "websocket.disconnect":
  70. raise WebSocketDisconnect(msg["code"])
  71. text, binary = msg.get('text'), msg.get('bytes')
  72. event = None
  73. if text:
  74. event = json.loads(text)
  75. elif binary:
  76. event = deserialize_binary_event(binary)
  77. except WebSocketDisconnect:
  78. if not close_from_session_tag:
  79. # close session because client disconnected to server
  80. session.close(nonblock=True)
  81. logger.debug("WebSocket closed from client")
  82. break
  83. if event is not None:
  84. session.send_client_event(event)
  85. return [
  86. Route("/", http_endpoint),
  87. WebSocketRoute("/", websocket_endpoint)
  88. ]
  89. def webio_routes(applications, cdn=True, allowed_origins=None, check_origin=None):
  90. """Get the FastAPI/Starlette routes for running PyWebIO applications.
  91. The API communicates with the browser using WebSocket protocol.
  92. The arguments of ``webio_routes()`` have the same meaning as for :func:`pywebio.platform.fastapi.start_server`
  93. .. versionadded:: 1.3
  94. :return: FastAPI/Starlette routes
  95. """
  96. try:
  97. import websockets
  98. except Exception:
  99. raise RuntimeError(strip_space("""
  100. Missing dependency package `websockets` for websocket support.
  101. You can install it with the following command:
  102. pip install websockets
  103. """.strip(), n=8)) from None
  104. applications = make_applications(applications)
  105. for target in applications.values():
  106. register_session_implement_for_target(target)
  107. cdn = cdn_validation(cdn, 'error')
  108. if check_origin is None:
  109. check_origin_func = partial(OriginChecker.check_origin, allowed_origins=allowed_origins or [])
  110. else:
  111. check_origin_func = lambda origin, host: OriginChecker.is_same_site(origin, host) or check_origin(origin)
  112. return _webio_routes(applications=applications, cdn=cdn, check_origin_func=check_origin_func)
  113. def start_server(applications, port=0, host='', cdn=True,
  114. static_dir=None, remote_access=False, debug=False,
  115. allowed_origins=None, check_origin=None,
  116. auto_open_webbrowser=False,
  117. **uvicorn_settings):
  118. """Start a FastAPI/Starlette server using uvicorn to provide the PyWebIO application as a web service.
  119. :param bool debug: Boolean indicating if debug tracebacks should be returned on errors.
  120. :param uvicorn_settings: Additional keyword arguments passed to ``uvicorn.run()``.
  121. For details, please refer: https://www.uvicorn.org/settings/
  122. The rest arguments of ``start_server()`` have the same meaning as for :func:`pywebio.platform.tornado.start_server`
  123. .. versionadded:: 1.3
  124. """
  125. app = asgi_app(applications, cdn=cdn, static_dir=static_dir, debug=debug,
  126. allowed_origins=allowed_origins, check_origin=check_origin)
  127. if auto_open_webbrowser:
  128. asyncio.get_event_loop().create_task(open_webbrowser_on_server_started('localhost', port))
  129. if not host:
  130. host = '0.0.0.0'
  131. if port == 0:
  132. port = get_free_port()
  133. print_listen_address(host, port)
  134. if remote_access:
  135. start_remote_access_service(local_port=port)
  136. uvicorn.run(app, host=host, port=port, **uvicorn_settings)
  137. def asgi_app(applications, cdn=True, static_dir=None, debug=False, allowed_origins=None, check_origin=None):
  138. """Get the starlette/Fastapi ASGI app for running PyWebIO applications.
  139. Use :func:`pywebio.platform.fastapi.webio_routes` if you prefer handling static files yourself.
  140. The arguments of ``asgi_app()`` have the same meaning as for :func:`pywebio.platform.fastapi.start_server`
  141. :Example:
  142. To be used with ``FastAPI.mount()`` to include pywebio as a subapp into an existing Starlette/FastAPI application::
  143. from fastapi import FastAPI
  144. from pywebio.platform.fastapi import asgi_app
  145. from pywebio.output import put_text
  146. app = FastAPI()
  147. subapp = asgi_app(lambda: put_text("hello from pywebio"))
  148. app.mount("/pywebio", subapp)
  149. :Returns: Starlette/Fastapi ASGI app
  150. .. versionadded:: 1.3
  151. """
  152. try:
  153. from starlette.staticfiles import StaticFiles
  154. except Exception:
  155. raise RuntimeError(strip_space("""
  156. Missing dependency package `aiofiles` for static file serving.
  157. You can install it with the following command:
  158. pip install aiofiles
  159. """.strip(), n=8)) from None
  160. debug = Session.debug = os.environ.get('PYWEBIO_DEBUG', debug)
  161. cdn = cdn_validation(cdn, 'warn')
  162. if cdn is False:
  163. cdn = 'pywebio_static'
  164. routes = webio_routes(applications, cdn=cdn, allowed_origins=allowed_origins, check_origin=check_origin)
  165. if static_dir:
  166. routes.append(Mount('/static', app=StaticFiles(directory=static_dir), name="static"))
  167. routes.append(Mount('/pywebio_static', app=StaticFiles(directory=STATIC_PATH), name="pywebio_static"))
  168. return Starlette(routes=routes, debug=debug)