aiohttp.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import asyncio
  2. import fnmatch
  3. import json
  4. import logging
  5. from functools import partial
  6. from os import path, listdir
  7. from urllib.parse import urlparse
  8. from aiohttp import web
  9. from .tornado import open_webbrowser_on_server_started
  10. from .utils import make_applications, render_page
  11. from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
  12. from ..session.base import get_session_info_from_headers
  13. from ..utils import get_free_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction
  14. logger = logging.getLogger(__name__)
  15. def _check_origin(origin, allowed_origins, host):
  16. if _is_same_site(origin, host):
  17. return True
  18. return any(
  19. fnmatch.fnmatch(origin, patten)
  20. for patten in allowed_origins
  21. )
  22. def _is_same_site(origin, host):
  23. """判断 origin 和 host 是否一致。origin 和 host 都为http协议请求头"""
  24. parsed_origin = urlparse(origin)
  25. origin = parsed_origin.netloc
  26. origin = origin.lower()
  27. # Check to see that origin matches host directly, including ports
  28. return origin == host
  29. def _webio_handler(applications, websocket_settings, check_origin_func=_is_same_site):
  30. """获取用于Tornado进行整合的RequestHandle类
  31. :param dict applications: 任务名->任务函数 的映射
  32. :param callable check_origin_func: check_origin_func(origin, handler) -> bool
  33. :return: Tornado RequestHandle类
  34. """
  35. ioloop = asyncio.get_event_loop()
  36. async def wshandle(request: web.Request):
  37. origin = request.headers.get('origin')
  38. if origin and not check_origin_func(origin=origin, host=request.host):
  39. return web.Response(status=403, text="Cross origin websockets not allowed")
  40. if request.headers.get("Upgrade", "").lower() != "websocket":
  41. # Backward compatible
  42. if request.query.getone('test', ''):
  43. return web.Response(text="")
  44. app_name = request.query.getone('app', 'index')
  45. app = applications.get(app_name) or applications['index']
  46. html = render_page(app, protocol='ws')
  47. return web.Response(body=html, content_type='text/html')
  48. ws = web.WebSocketResponse(**websocket_settings)
  49. await ws.prepare(request)
  50. close_from_session_tag = False # 是否由session主动关闭连接
  51. def send_msg_to_client(session: Session):
  52. for msg in session.get_task_commands():
  53. msg_str = json.dumps(msg)
  54. ioloop.create_task(ws.send_str(msg_str))
  55. def close_from_session():
  56. nonlocal close_from_session_tag
  57. close_from_session_tag = True
  58. ioloop.create_task(ws.close())
  59. logger.debug("WebSocket closed from session")
  60. session_info = get_session_info_from_headers(request.headers)
  61. session_info['user_ip'] = request.remote
  62. session_info['request'] = request
  63. session_info['backend'] = 'aiohttp'
  64. app_name = request.query.getone('app', 'index')
  65. application = applications.get(app_name) or applications['index']
  66. if iscoroutinefunction(application) or isgeneratorfunction(application):
  67. session = CoroutineBasedSession(application, session_info=session_info,
  68. on_task_command=send_msg_to_client,
  69. on_session_close=close_from_session)
  70. else:
  71. session = ThreadBasedSession(application, session_info=session_info,
  72. on_task_command=send_msg_to_client,
  73. on_session_close=close_from_session, loop=ioloop)
  74. async for msg in ws:
  75. if msg.type == web.WSMsgType.text:
  76. data = msg.json()
  77. if data is not None:
  78. session.send_client_event(data)
  79. elif msg.type == web.WSMsgType.binary:
  80. pass
  81. elif msg.type == web.WSMsgType.close:
  82. if not close_from_session_tag:
  83. session.close()
  84. logger.debug("WebSocket closed from client")
  85. return ws
  86. return wshandle
  87. def webio_handler(applications, allowed_origins=None, check_origin=None, websocket_settings=None):
  88. """获取在aiohttp中运行PyWebIO任务函数的 `Request Handler <https://docs.aiohttp.org/en/stable/web_quickstart.html#aiohttp-web-handler>`_ 协程。
  89. Request Handler基于WebSocket协议与浏览器进行通讯。
  90. :param list/dict/callable applications: PyWebIO应用。
  91. :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
  92. :param callable check_origin: 请求来源检查函数。
  93. :param dict websocket_settings: 创建 aiohttp WebSocketResponse 时使用的参数。见 https://docs.aiohttp.org/en/stable/web_reference.html#websocketresponse
  94. 关于 ``applications`` 、 ``allowed_origins`` 、 ``check_origin`` 参数的详细说明见 :func:`pywebio.platform.aiohttp.start_server` 的同名参数。
  95. :return: aiohttp Request Handler
  96. """
  97. applications = make_applications(applications)
  98. for target in applications.values():
  99. register_session_implement_for_target(target)
  100. websocket_settings = websocket_settings or {}
  101. if check_origin is None:
  102. check_origin_func = partial(_check_origin, allowed_origins=allowed_origins or [])
  103. else:
  104. check_origin_func = lambda origin, handler: _is_same_site(origin, handler) or check_origin(origin)
  105. return _webio_handler(applications=applications,
  106. check_origin_func=check_origin_func,
  107. websocket_settings=websocket_settings)
  108. def static_routes(prefix='/'):
  109. """获取用于提供PyWebIO静态文件的aiohttp路由列表
  110. :param str prefix: 静态文件托管的URL路径,默认为根路径 ``/``
  111. :return: aiohttp路由列表
  112. """
  113. async def index(request):
  114. return web.FileResponse(path.join(STATIC_PATH, 'index.html'))
  115. files = [path.join(STATIC_PATH, d) for d in listdir(STATIC_PATH)]
  116. dirs = filter(path.isdir, files)
  117. routes = [web.static(prefix + path.basename(d), d) for d in dirs]
  118. routes.append(web.get(prefix, index))
  119. return routes
  120. def start_server(applications, port=0, host='', debug=False,
  121. allowed_origins=None, check_origin=None,
  122. auto_open_webbrowser=False,
  123. websocket_settings=None,
  124. **aiohttp_settings):
  125. """启动一个 aiohttp server 将PyWebIO应用作为Web服务提供。
  126. :param list/dict/callable applications: PyWebIO应用. 格式同 :func:`pywebio.platform.tornado.start_server` 的 ``applications`` 参数
  127. :param int port: 服务监听的端口。设置为 ``0`` 时,表示自动选择可用端口。
  128. :param str host: 服务绑定的地址。 ``host`` 可以是IP地址或者为hostname。如果为hostname,服务会监听所有与该hostname关联的IP地址。
  129. 通过设置 ``host`` 为空字符串或 ``None`` 来将服务绑定到所有可用的地址上。
  130. :param bool debug: 是否开启asyncio的Debug模式
  131. :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。格式同 :func:`pywebio.platform.tornado.start_server` 的 ``allowed_origins`` 参数
  132. :param callable check_origin: 请求来源检查函数。格式同 :func:`pywebio.platform.tornado.start_server` 的 ``check_origin`` 参数
  133. :param bool auto_open_webbrowser: 当服务启动后,是否自动打开浏览器来访问服务。(该操作需要操作系统支持)
  134. :param dict websocket_settings: 创建 aiohttp WebSocketResponse 时使用的参数。见 https://docs.aiohttp.org/en/stable/web_reference.html#websocketresponse
  135. :param aiohttp_settings: 需要传给 aiohttp Application 的参数。可用参数见 https://docs.aiohttp.org/en/stable/web_reference.html#application
  136. """
  137. kwargs = locals()
  138. if not host:
  139. host = '0.0.0.0'
  140. if port == 0:
  141. port = get_free_port()
  142. handler = webio_handler(applications, allowed_origins=allowed_origins, check_origin=check_origin,
  143. websocket_settings=websocket_settings)
  144. app = web.Application(**aiohttp_settings)
  145. app.router.add_routes([web.get('/', handler)])
  146. app.router.add_routes(static_routes())
  147. if auto_open_webbrowser:
  148. asyncio.get_event_loop().create_task(open_webbrowser_on_server_started('localhost', port))
  149. if debug:
  150. logging.getLogger("asyncio").setLevel(logging.DEBUG)
  151. print('Listen on %s:%s' % (host, port))
  152. web.run_app(app, host=host, port=port)