# aiohttp.py (8.3 KB)
  1. import asyncio
  2. import fnmatch
  3. import json
  4. import logging
  5. import os
  6. import typing
  7. from functools import partial
  8. from urllib.parse import urlparse
  9. from aiohttp import web
  10. from . import page
  11. from .adaptor import ws as ws_adaptor
  12. from .page import make_applications, render_page
  13. from .remote_access import start_remote_access_service
  14. from .tornado import open_webbrowser_on_server_started
  15. from .utils import cdn_validation, print_listen_address
  16. from ..session import register_session_implement_for_target, Session
  17. from ..session.base import get_session_info_from_headers
  18. from ..utils import get_free_port, STATIC_PATH, parse_file_size
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  20. def _check_origin(origin, allowed_origins, host):
  21. if _is_same_site(origin, host):
  22. return True
  23. return any(
  24. fnmatch.fnmatch(origin, pattern)
  25. for pattern in allowed_origins
  26. )
  27. def _is_same_site(origin, host):
  28. """判断 origin 和 host 是否一致。origin 和 host 都为http协议请求头"""
  29. parsed_origin = urlparse(origin)
  30. origin = parsed_origin.netloc
  31. origin = origin.lower()
  32. # Check to see that origin matches host directly, including ports
  33. return origin == host
  34. class WebSocketConnection(ws_adaptor.WebSocketConnection):
  35. def __init__(self, ws: web.WebSocketResponse, http: web.Request, ioloop):
  36. self.ws = ws
  37. self.http = http
  38. self.ioloop = ioloop
  39. def get_query_argument(self, name) -> typing.Optional[str]:
  40. return self.http.query.getone(name, None)
  41. def make_session_info(self) -> dict:
  42. session_info = get_session_info_from_headers(self.http.headers)
  43. session_info['user_ip'] = self.http.remote
  44. session_info['request'] = self.http
  45. session_info['backend'] = 'aiohttp'
  46. session_info['protocol'] = 'websocket'
  47. return session_info
  48. def write_message(self, message: dict):
  49. msg_str = json.dumps(message)
  50. self.ioloop.create_task(self.ws.send_str(msg_str))
  51. def closed(self) -> bool:
  52. return self.ws.closed
  53. def close(self):
  54. self.ioloop.create_task(self.ws.close())
  55. def _webio_handler(applications, cdn, websocket_settings, reconnect_timeout=0, check_origin_func=_is_same_site):
  56. """
  57. :param dict applications: dict of `name -> task function`
  58. :param bool/str cdn: Whether to load front-end static resources from CDN
  59. :param callable check_origin_func: check_origin_func(origin, host) -> bool
  60. :return: aiohttp Request Handler
  61. """
  62. ws_adaptor.set_expire_second(reconnect_timeout)
  63. async def wshandle(request: web.Request):
  64. ioloop = asyncio.get_event_loop()
  65. asyncio.get_event_loop().create_task(ws_adaptor.session_clean_task())
  66. origin = request.headers.get('origin')
  67. if origin and not check_origin_func(origin=origin, host=request.host):
  68. return web.Response(status=403, text="Cross origin websockets not allowed")
  69. if request.headers.get("Upgrade", "").lower() != "websocket":
  70. # Backward compatible
  71. if request.query.getone('test', ''):
  72. return web.Response(text="")
  73. app_name = request.query.getone('app', 'index')
  74. app = applications.get(app_name) or applications['index']
  75. no_cdn = cdn is True and request.query.getone('_pywebio_cdn', '') == 'false'
  76. html = render_page(app, protocol='ws', cdn=False if no_cdn else cdn)
  77. return web.Response(body=html, content_type='text/html')
  78. ws = web.WebSocketResponse(**websocket_settings)
  79. await ws.prepare(request)
  80. app_name = request.query.getone('app', 'index')
  81. application = applications.get(app_name) or applications['index']
  82. conn = WebSocketConnection(ws, request, ioloop)
  83. handler = ws_adaptor.WebSocketHandler(
  84. connection=conn, application=application, reconnectable=bool(reconnect_timeout), ioloop=ioloop
  85. )
  86. # see: https://github.com/aio-libs/aiohttp/issues/1768
  87. try:
  88. async for msg in ws:
  89. if msg.type in (web.WSMsgType.text, web.WSMsgType.binary):
  90. handler.send_client_data(msg.data)
  91. elif msg.type == web.WSMsgType.close:
  92. raise asyncio.CancelledError()
  93. finally:
  94. handler.notify_connection_lost()
  95. return ws
  96. return wshandle
  97. def webio_handler(applications, cdn=True, reconnect_timeout=0, allowed_origins=None, check_origin=None,
  98. max_payload_size='200M', websocket_settings=None):
  99. """Get the `Request Handler <https://docs.aiohttp.org/en/stable/web_quickstart.html#aiohttp-web-handler>`_ coroutine for running PyWebIO applications in aiohttp.
  100. The handler communicates with the browser by WebSocket protocol.
  101. The arguments of ``webio_handler()`` have the same meaning as for :func:`pywebio.platform.aiohttp.start_server`
  102. :return: aiohttp Request Handler
  103. """
  104. applications = make_applications(applications)
  105. for target in applications.values():
  106. register_session_implement_for_target(target)
  107. websocket_settings = websocket_settings or {}
  108. page.MAX_PAYLOAD_SIZE = max_payload_size = parse_file_size(max_payload_size)
  109. websocket_settings.setdefault('max_msg_size', max_payload_size)
  110. cdn = cdn_validation(cdn, 'error')
  111. if check_origin is None:
  112. check_origin_func = partial(_check_origin, allowed_origins=allowed_origins or [])
  113. else:
  114. check_origin_func = lambda origin, host: _is_same_site(origin, host) or check_origin(origin)
  115. return _webio_handler(applications=applications, cdn=cdn,
  116. check_origin_func=check_origin_func,
  117. reconnect_timeout=reconnect_timeout,
  118. websocket_settings=websocket_settings)
  119. def static_routes(prefix='/'):
  120. """获取用于提供PyWebIO静态文件的aiohttp路由列表
  121. Get the aiohttp routes list for PyWebIO static files hosting.
  122. :param str prefix: The URL path of static file hosting, the default is the root path ``/``
  123. :return: aiohttp routes list
  124. """
  125. files = [os.path.join(STATIC_PATH, d) for d in os.listdir(STATIC_PATH)]
  126. dirs = filter(os.path.isdir, files)
  127. routes = [web.static(prefix + os.path.basename(d), d) for d in dirs]
  128. return routes
  129. def start_server(applications, port=0, host='', debug=False,
  130. cdn=True, static_dir=None, remote_access=False,
  131. reconnect_timeout=0,
  132. allowed_origins=None, check_origin=None,
  133. auto_open_webbrowser=False,
  134. max_payload_size='200M',
  135. websocket_settings=None,
  136. **aiohttp_settings):
  137. """Start a aiohttp server to provide the PyWebIO application as a web service.
  138. :param dict websocket_settings: The parameters passed to the constructor of ``aiohttp.web.WebSocketResponse``.
  139. For details, please refer: https://docs.aiohttp.org/en/stable/web_reference.html#websocketresponse
  140. :param aiohttp_settings: Additional keyword arguments passed to the constructor of ``aiohttp.web.Application``.
  141. For details, please refer: https://docs.aiohttp.org/en/stable/web_reference.html#application
  142. The rest arguments of ``start_server()`` have the same meaning as for :func:`pywebio.platform.tornado.start_server`
  143. """
  144. kwargs = locals()
  145. if not host:
  146. host = '0.0.0.0'
  147. if port == 0:
  148. port = get_free_port()
  149. cdn = cdn_validation(cdn, 'warn')
  150. handler = webio_handler(applications, cdn=cdn, allowed_origins=allowed_origins, reconnect_timeout=reconnect_timeout,
  151. check_origin=check_origin, max_payload_size=max_payload_size,
  152. websocket_settings=websocket_settings)
  153. app = web.Application(**aiohttp_settings)
  154. app.router.add_routes([web.get('/', handler)])
  155. if static_dir is not None:
  156. app.router.add_routes([web.static('/static', static_dir)])
  157. app.router.add_routes(static_routes())
  158. if auto_open_webbrowser:
  159. asyncio.get_event_loop().create_task(open_webbrowser_on_server_started('127.0.0.1', port))
  160. debug = Session.debug = os.environ.get('PYWEBIO_DEBUG', debug)
  161. if debug:
  162. logging.getLogger("asyncio").setLevel(logging.DEBUG)
  163. print_listen_address(host, port)
  164. if remote_access:
  165. start_remote_access_service(local_port=port)
  166. web.run_app(app, host=host, port=port)