aiohttp.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import asyncio
  2. import fnmatch
  3. import json
  4. import logging
  5. from functools import partial
  6. from os import path, listdir
  7. from urllib.parse import urlparse
  8. from aiohttp import web
  9. from .tornado import open_webbrowser_on_server_started
  10. from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
  11. from ..session.base import get_session_info_from_headers
  12. from ..utils import get_free_port, STATIC_PATH
# Module-level logger used for this backend's debug messages.
logger = logging.getLogger(__name__)
  14. def _check_origin(origin, allowed_origins, host):
  15. if _is_same_site(origin, host):
  16. return True
  17. return any(
  18. fnmatch.fnmatch(origin, patten)
  19. for patten in allowed_origins
  20. )
  21. def _is_same_site(origin, host):
  22. """判断 origin 和 host 是否一致。origin 和 host 都为http协议请求头"""
  23. parsed_origin = urlparse(origin)
  24. origin = parsed_origin.netloc
  25. origin = origin.lower()
  26. # Check to see that origin matches host directly, including ports
  27. return origin == host
  28. def _webio_handler(target, session_cls, websocket_settings, check_origin_func=_is_same_site):
  29. """获取用于Tornado进行整合的RequestHandle类
  30. :param target: 任务函数
  31. :param session_cls: 会话实现类
  32. :param callable check_origin_func: check_origin_func(origin, handler) -> bool
  33. :return: Tornado RequestHandle类
  34. """
  35. ioloop = asyncio.get_event_loop()
  36. async def wshandle(request: web.Request):
  37. origin = request.headers.get('origin')
  38. if origin and not check_origin_func(origin=origin, host=request.host):
  39. return web.Response(status=403, text="Cross origin websockets not allowed")
  40. ws = web.WebSocketResponse(**websocket_settings)
  41. await ws.prepare(request)
  42. close_from_session_tag = False # 是否由session主动关闭连接
  43. def send_msg_to_client(session: Session):
  44. for msg in session.get_task_commands():
  45. msg_str = json.dumps(msg)
  46. ioloop.create_task(ws.send_str(msg_str))
  47. def close_from_session():
  48. nonlocal close_from_session_tag
  49. close_from_session_tag = True
  50. ioloop.create_task(ws.close())
  51. logger.debug("WebSocket closed from session")
  52. session_info = get_session_info_from_headers(request.headers)
  53. session_info['user_ip'] = request.remote
  54. session_info['request'] = request
  55. session_info['backend'] = 'aiohttp'
  56. if session_cls is CoroutineBasedSession:
  57. session = CoroutineBasedSession(target, session_info=session_info,
  58. on_task_command=send_msg_to_client,
  59. on_session_close=close_from_session)
  60. elif session_cls is ThreadBasedSession:
  61. session = ThreadBasedSession(target, session_info=session_info,
  62. on_task_command=send_msg_to_client,
  63. on_session_close=close_from_session, loop=ioloop)
  64. else:
  65. raise RuntimeError("Don't support session type:%s" % session_cls)
  66. async for msg in ws:
  67. if msg.type == web.WSMsgType.text:
  68. data = msg.json()
  69. if data is not None:
  70. session.send_client_event(data)
  71. elif msg.type == web.WSMsgType.binary:
  72. pass
  73. elif msg.type == web.WSMsgType.close:
  74. if not close_from_session_tag:
  75. session.close()
  76. logger.debug("WebSocket closed from client")
  77. return ws
  78. return wshandle
  79. def webio_handler(target, allowed_origins=None, check_origin=None, websocket_settings=None):
  80. """获取在aiohttp中运行PyWebIO任务函数的 `Request Handle <https://docs.aiohttp.org/en/stable/web_quickstart.html#aiohttp-web-handler>`_ 协程。
  81. Request Handle基于WebSocket协议与浏览器进行通讯。
  82. :param target: 任务函数。任务函数为协程函数时,使用 :ref:`基于协程的会话实现 <coroutine_based_session>` ;任务函数为普通函数时,使用基于线程的会话实现。
  83. :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
  84. 来源包含协议和域名和端口部分,允许使用 Unix shell 风格的匹配模式:
  85. - ``*`` 为通配符
  86. - ``?`` 匹配单个字符
  87. - ``[seq]`` 匹配seq内的字符
  88. - ``[!seq]`` 匹配不在seq内的字符
  89. 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
  90. :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
  91. 返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
  92. :param dict websocket_settings: 创建 aiohttp WebSocketResponse 时使用的参数。见 https://docs.aiohttp.org/en/stable/web_reference.html#websocketresponse
  93. :return: aiohttp Request Handler
  94. """
  95. session_cls = register_session_implement_for_target(target)
  96. websocket_settings = websocket_settings or {}
  97. if check_origin is None:
  98. check_origin_func = partial(_check_origin, allowed_origins=allowed_origins or [])
  99. else:
  100. check_origin_func = lambda origin, handler: _is_same_site(origin, handler) or check_origin(origin)
  101. return _webio_handler(target=target, session_cls=session_cls, check_origin_func=check_origin_func,
  102. websocket_settings=websocket_settings)
  103. def static_routes(static_path):
  104. """获取用于提供PyWebIO静态文件的aiohttp路由"""
  105. async def index(request):
  106. return web.FileResponse(path.join(STATIC_PATH, 'index.html'))
  107. files = [path.join(static_path, d) for d in listdir(static_path)]
  108. dirs = filter(path.isdir, files)
  109. routes = [web.static('/' + path.basename(d), d) for d in dirs]
  110. routes.append(web.get('/', index))
  111. return routes
  112. def start_server(target, port=0, host='', debug=False,
  113. allowed_origins=None, check_origin=None,
  114. auto_open_webbrowser=False,
  115. websocket_settings=None,
  116. **aiohttp_settings):
  117. """启动一个 aiohttp server 将 ``target`` 任务函数作为Web服务提供。
  118. :param target: 任务函数。任务函数为协程函数时,使用 :ref:`基于协程的会话实现 <coroutine_based_session>` ;任务函数为普通函数时,使用基于线程的会话实现。
  119. :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
  120. :param int port: server bind port. set ``0`` to find a free port number to use
  121. :param str host: server bind host. ``host`` may be either an IP address or hostname. If it's a hostname,
  122. the server will listen on all IP addresses associated with the name.
  123. set empty string or to listen on all available interfaces.
  124. :param bool debug: asyncio Debug Mode
  125. :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
  126. 来源包含协议和域名和端口部分,允许使用 Unix shell 风格的匹配模式:
  127. - ``*`` 为通配符
  128. - ``?`` 匹配单个字符
  129. - ``[seq]`` 匹配seq内的字符
  130. - ``[!seq]`` 匹配不在seq内的字符
  131. 比如 ``https://*.example.com`` 、 ``*://*.example.com``
  132. :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
  133. 返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
  134. :param bool auto_open_webbrowser: Whether or not auto open web browser when server is started (if the operating system allows it) .
  135. :param dict websocket_settings: 创建 aiohttp WebSocketResponse 时使用的参数。见 https://docs.aiohttp.org/en/stable/web_reference.html#websocketresponse
  136. :param aiohttp_settings: 需要传给 aiohttp Application 的参数。可用参数见 https://docs.aiohttp.org/en/stable/web_reference.html#application
  137. """
  138. kwargs = locals()
  139. if not host:
  140. host = '0.0.0.0'
  141. if port == 0:
  142. port = get_free_port()
  143. handler = webio_handler(target, allowed_origins=allowed_origins, check_origin=check_origin,
  144. websocket_settings=websocket_settings)
  145. app = web.Application(**aiohttp_settings)
  146. app.add_routes([web.get('/io', handler)])
  147. app.add_routes(static_routes(STATIC_PATH))
  148. if auto_open_webbrowser:
  149. asyncio.get_event_loop().create_task(open_webbrowser_on_server_started('localhost', port))
  150. if debug:
  151. logging.getLogger("asyncio").setLevel(logging.DEBUG)
  152. print('Listen on %s:%s' % (host, port))
  153. web.run_app(app, host=host, port=port)