aiohttp.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import asyncio
  2. import fnmatch
  3. import json
  4. import logging
  5. from functools import partial
  6. from os import path, listdir
  7. from urllib.parse import urlparse
  8. from aiohttp import web
  9. from .remote_access import start_remote_access_service
  10. from .tornado import open_webbrowser_on_server_started
  11. from .utils import make_applications, render_page, cdn_validation, deserialize_binary_event
  12. from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
  13. from ..session.base import get_session_info_from_headers
  14. from ..utils import get_free_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction
  15. logger = logging.getLogger(__name__)
  16. def _check_origin(origin, allowed_origins, host):
  17. if _is_same_site(origin, host):
  18. return True
  19. return any(
  20. fnmatch.fnmatch(origin, patten)
  21. for patten in allowed_origins
  22. )
  23. def _is_same_site(origin, host):
  24. """判断 origin 和 host 是否一致。origin 和 host 都为http协议请求头"""
  25. parsed_origin = urlparse(origin)
  26. origin = parsed_origin.netloc
  27. origin = origin.lower()
  28. # Check to see that origin matches host directly, including ports
  29. return origin == host
  30. def _webio_handler(applications, cdn, websocket_settings, check_origin_func=_is_same_site):
  31. """
  32. :param dict applications: dict of `name -> task function`
  33. :param bool/str cdn: Whether to load front-end static resources from CDN
  34. :param callable check_origin_func: check_origin_func(origin, host) -> bool
  35. :return: aiohttp Request Handler
  36. """
  37. ioloop = asyncio.get_event_loop()
  38. async def wshandle(request: web.Request):
  39. origin = request.headers.get('origin')
  40. if origin and not check_origin_func(origin=origin, host=request.host):
  41. return web.Response(status=403, text="Cross origin websockets not allowed")
  42. if request.headers.get("Upgrade", "").lower() != "websocket":
  43. # Backward compatible
  44. if request.query.getone('test', ''):
  45. return web.Response(text="")
  46. app_name = request.query.getone('app', 'index')
  47. app = applications.get(app_name) or applications['index']
  48. html = render_page(app, protocol='ws', cdn=cdn)
  49. return web.Response(body=html, content_type='text/html')
  50. ws = web.WebSocketResponse(**websocket_settings)
  51. await ws.prepare(request)
  52. close_from_session_tag = False # 是否由session主动关闭连接
  53. def send_msg_to_client(session: Session):
  54. for msg in session.get_task_commands():
  55. msg_str = json.dumps(msg)
  56. ioloop.create_task(ws.send_str(msg_str))
  57. def close_from_session():
  58. nonlocal close_from_session_tag
  59. close_from_session_tag = True
  60. ioloop.create_task(ws.close())
  61. logger.debug("WebSocket closed from session")
  62. session_info = get_session_info_from_headers(request.headers)
  63. session_info['user_ip'] = request.remote
  64. session_info['request'] = request
  65. session_info['backend'] = 'aiohttp'
  66. session_info['protocol'] = 'websocket'
  67. app_name = request.query.getone('app', 'index')
  68. application = applications.get(app_name) or applications['index']
  69. if iscoroutinefunction(application) or isgeneratorfunction(application):
  70. session = CoroutineBasedSession(application, session_info=session_info,
  71. on_task_command=send_msg_to_client,
  72. on_session_close=close_from_session)
  73. else:
  74. session = ThreadBasedSession(application, session_info=session_info,
  75. on_task_command=send_msg_to_client,
  76. on_session_close=close_from_session, loop=ioloop)
  77. # see: https://github.com/aio-libs/aiohttp/issues/1768
  78. try:
  79. async for msg in ws:
  80. if msg.type == web.WSMsgType.text:
  81. data = msg.json()
  82. elif msg.type == web.WSMsgType.binary:
  83. data = deserialize_binary_event(msg.data)
  84. elif msg.type == web.WSMsgType.close:
  85. raise asyncio.CancelledError()
  86. if data is not None:
  87. session.send_client_event(data)
  88. finally:
  89. if not close_from_session_tag:
  90. # close session because client disconnected to server
  91. session.close(nonblock=True)
  92. logger.debug("WebSocket closed from client")
  93. return ws
  94. return wshandle
  95. def webio_handler(applications, cdn=True, allowed_origins=None, check_origin=None, websocket_settings=None):
  96. """Get the `Request Handler <https://docs.aiohttp.org/en/stable/web_quickstart.html#aiohttp-web-handler>`_ coroutine for running PyWebIO applications in aiohttp.
  97. The handler communicates with the browser by WebSocket protocol.
  98. The arguments of ``webio_handler()`` have the same meaning as for :func:`pywebio.platform.aiohttp.start_server`
  99. :return: aiohttp Request Handler
  100. """
  101. applications = make_applications(applications)
  102. for target in applications.values():
  103. register_session_implement_for_target(target)
  104. websocket_settings = websocket_settings or {}
  105. cdn = cdn_validation(cdn, 'error')
  106. if check_origin is None:
  107. check_origin_func = partial(_check_origin, allowed_origins=allowed_origins or [])
  108. else:
  109. check_origin_func = lambda origin, host: _is_same_site(origin, host) or check_origin(origin)
  110. return _webio_handler(applications=applications, cdn=cdn,
  111. check_origin_func=check_origin_func,
  112. websocket_settings=websocket_settings)
  113. def static_routes(prefix='/'):
  114. """获取用于提供PyWebIO静态文件的aiohttp路由列表
  115. Get the aiohttp routes list for PyWebIO static files hosting.
  116. :param str prefix: The URL path of static file hosting, the default is the root path ``/``
  117. :return: aiohttp routes list
  118. """
  119. async def index(request):
  120. return web.FileResponse(path.join(STATIC_PATH, 'index.html'))
  121. files = [path.join(STATIC_PATH, d) for d in listdir(STATIC_PATH)]
  122. dirs = filter(path.isdir, files)
  123. routes = [web.static(prefix + path.basename(d), d) for d in dirs]
  124. routes.append(web.get(prefix, index))
  125. return routes
  126. def start_server(applications, port=0, host='', debug=False,
  127. cdn=True, static_dir=None, remote_access=False,
  128. allowed_origins=None, check_origin=None,
  129. auto_open_webbrowser=False,
  130. websocket_settings=None,
  131. **aiohttp_settings):
  132. """Start a aiohttp server to provide the PyWebIO application as a web service.
  133. :param dict websocket_settings: The parameters passed to the constructor of ``aiohttp.web.WebSocketResponse``.
  134. For details, please refer: https://docs.aiohttp.org/en/stable/web_reference.html#websocketresponse
  135. :param aiohttp_settings: Additional keyword arguments passed to the constructor of ``aiohttp.web.Application``.
  136. For details, please refer: https://docs.aiohttp.org/en/stable/web_reference.html#application
  137. The rest arguments of ``start_server()`` have the same meaning as for :func:`pywebio.platform.tornado.start_server`
  138. """
  139. kwargs = locals()
  140. if not host:
  141. host = '0.0.0.0'
  142. if port == 0:
  143. port = get_free_port()
  144. cdn = cdn_validation(cdn, 'warn')
  145. handler = webio_handler(applications, cdn=cdn, allowed_origins=allowed_origins,
  146. check_origin=check_origin, websocket_settings=websocket_settings)
  147. app = web.Application(**aiohttp_settings)
  148. app.router.add_routes([web.get('/', handler)])
  149. if static_dir is not None:
  150. app.router.add_routes([web.static('/static', static_dir)])
  151. app.router.add_routes(static_routes())
  152. if auto_open_webbrowser:
  153. asyncio.get_event_loop().create_task(open_webbrowser_on_server_started('localhost', port))
  154. if debug:
  155. logging.getLogger("asyncio").setLevel(logging.DEBUG)
  156. print('Listen on %s:%s' % (host, port))
  157. if remote_access or remote_access == {}:
  158. if remote_access is True: remote_access = {}
  159. start_remote_access_service(**remote_access, local_port=port)
  160. web.run_app(app, host=host, port=port)