123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- import asyncio
- import fnmatch
- import json
- import logging
- from functools import partial
- from os import path, listdir
- from urllib.parse import urlparse
- from aiohttp import web
- from .tornado import open_webbrowser_on_server_started
- from .utils import make_applications
- from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
- from ..session.base import get_session_info_from_headers
- from ..utils import get_free_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction
- logger = logging.getLogger(__name__)
def _check_origin(origin, allowed_origins, host):
    """Decide whether a request origin is acceptable.

    :param str origin: value of the request's ``Origin`` header
    :param allowed_origins: iterable of Unix shell-style patterns matched
        against ``origin`` with ``fnmatch.fnmatch``
    :param str host: host of the request, compared against the origin's netloc
    :return: ``True`` if the origin is same-site with ``host`` or matches any
        of the allowed patterns, ``False`` otherwise
    """
    # Same-site requests are always accepted, whatever the allow-list says.
    if _is_same_site(origin, host):
        return True

    for pattern in allowed_origins:
        if fnmatch.fnmatch(origin, pattern):
            return True
    return False
- def _is_same_site(origin, host):
- """判断 origin 和 host 是否一致。origin 和 host 都为http协议请求头"""
- parsed_origin = urlparse(origin)
- origin = parsed_origin.netloc
- origin = origin.lower()
- # Check to see that origin matches host directly, including ports
- return origin == host
def _webio_handler(applications, websocket_settings, check_origin_func=_is_same_site):
    """Build the aiohttp WebSocket request handler that runs PyWebIO applications.

    :param dict applications: mapping of task name -> task function
    :param dict websocket_settings: keyword arguments passed to ``web.WebSocketResponse``
    :param callable check_origin_func: ``check_origin_func(origin, host) -> bool``.
        It is called with the keyword arguments ``origin`` and ``host``.
    :return: aiohttp request handler coroutine function
    """
    ioloop = asyncio.get_event_loop()

    async def wshandle(request: web.Request):
        # Requests without an Origin header skip the origin check.
        origin = request.headers.get('origin')
        if origin and not check_origin_func(origin=origin, host=request.host):
            return web.Response(status=403, text="Cross origin websockets not allowed")

        ws = web.WebSocketResponse(**websocket_settings)
        await ws.prepare(request)

        close_from_session_tag = False  # True once the session itself closed the connection

        def send_msg_to_client(session: Session):
            # Forward every pending task command to the browser as a JSON text frame.
            for msg in session.get_task_commands():
                msg_str = json.dumps(msg)
                ioloop.create_task(ws.send_str(msg_str))

        def close_from_session():
            nonlocal close_from_session_tag
            close_from_session_tag = True
            ioloop.create_task(ws.close())
            logger.debug("WebSocket closed from session")

        session_info = get_session_info_from_headers(request.headers)
        session_info['user_ip'] = request.remote
        session_info['request'] = request
        session_info['backend'] = 'aiohttp'

        # Unknown or missing app names fall back to the 'index' application.
        app_name = request.query.getone('app', 'index')
        application = applications.get(app_name) or applications['index']

        if iscoroutinefunction(application) or isgeneratorfunction(application):
            session = CoroutineBasedSession(application, session_info=session_info,
                                            on_task_command=send_msg_to_client,
                                            on_session_close=close_from_session)
        else:
            session = ThreadBasedSession(application, session_info=session_info,
                                         on_task_command=send_msg_to_client,
                                         on_session_close=close_from_session, loop=ioloop)

        # aiohttp's async iteration over a WebSocketResponse stops on
        # CLOSE/CLOSING/CLOSED messages instead of yielding them, so a client
        # disconnect cannot be observed inside the loop. The session is closed
        # in ``finally`` instead, which also covers errors raised while
        # handling a message.
        try:
            async for msg in ws:
                if msg.type == web.WSMsgType.text:
                    data = msg.json()
                    if data is not None:
                        session.send_client_event(data)
                # Binary frames are ignored.
        finally:
            if not close_from_session_tag:
                session.close()
                logger.debug("WebSocket closed from client")

        return ws

    return wshandle
def webio_handler(applications, allowed_origins=None, check_origin=None, websocket_settings=None):
    """Get the `Request Handler <https://docs.aiohttp.org/en/stable/web_quickstart.html#aiohttp-web-handler>`_
    coroutine used to run PyWebIO task functions in aiohttp.

    The request handler communicates with the browser over the WebSocket protocol.

    :param list/dict/callable applications: PyWebIO application. Can be a task function,
        or a dict or list of task functions.
    :param list allowed_origins: origins the server accepts in addition to the current host.
        An origin consists of scheme, domain and port. Unix shell-style patterns are supported:

        - ``*`` matches everything
        - ``?`` matches any single character
        - ``[seq]`` matches any character in seq
        - ``[!seq]`` matches any character not in seq

        For example ``https://*.example.com`` or ``*://*.example.com``.
    :param callable check_origin: origin check function. Receives the request origin string
        (scheme, domain and port) and returns ``True/False``. When ``check_origin`` is set,
        ``allowed_origins`` is ignored. Same-site requests are accepted without
        consulting ``check_origin``.
    :param dict websocket_settings: arguments used when creating the aiohttp WebSocketResponse.
        See https://docs.aiohttp.org/en/stable/web_reference.html#websocketresponse
    :return: aiohttp Request Handler
    """
    applications = make_applications(applications)
    for target in applications.values():
        register_session_implement_for_target(target)

    websocket_settings = websocket_settings or {}

    if check_origin is None:
        check_origin_func = partial(_check_origin, allowed_origins=allowed_origins or [])
    else:
        # The handler invokes this with the keyword arguments ``origin`` and
        # ``host``, so the parameter names must match exactly.
        def check_origin_func(origin, host):
            return _is_same_site(origin, host) or check_origin(origin)

    return _webio_handler(applications=applications,
                          check_origin_func=check_origin_func,
                          websocket_settings=websocket_settings)
def static_routes(static_path):
    """Build the aiohttp routes that serve PyWebIO's static files.

    One static route is created for every sub-directory of ``static_path``,
    mounted under ``/<directory name>``, followed by a ``GET /`` route that
    returns ``index.html``.

    :param str static_path: directory whose sub-directories are exposed
    :return: list of aiohttp route definitions
    """

    async def index(request):
        # NOTE(review): the index page is read from the module-level STATIC_PATH,
        # not from ``static_path`` -- confirm this is intended.
        return web.FileResponse(path.join(STATIC_PATH, 'index.html'))

    routes = []
    for entry in listdir(static_path):
        full_path = path.join(static_path, entry)
        if path.isdir(full_path):
            routes.append(web.static('/' + path.basename(full_path), full_path))

    routes.append(web.get('/', index))
    return routes
def start_server(applications, port=0, host='', debug=False,
                 allowed_origins=None, check_origin=None,
                 auto_open_webbrowser=False,
                 websocket_settings=None,
                 **aiohttp_settings):
    """Start an aiohttp server that serves the PyWebIO applications as a web service.

    :param list/dict/callable applications: PyWebIO application. Can be a task function,
        or a dict or list of task functions.
    :param int port: server bind port. Set ``0`` to find a free port number to use.
    :param str host: server bind host. ``host`` may be either an IP address or hostname.
        If it's a hostname, the server will listen on all IP addresses associated with the name.
        Set to an empty string to listen on all available interfaces (``0.0.0.0``).
    :param bool debug: when true, set the ``asyncio`` logger level to ``DEBUG``.
    :param list allowed_origins: origins the server accepts in addition to the current host.
        An origin consists of scheme, domain and port. Unix shell-style patterns are supported:

        - ``*`` matches everything
        - ``?`` matches any single character
        - ``[seq]`` matches any character in seq
        - ``[!seq]`` matches any character not in seq

        For example ``https://*.example.com`` or ``*://*.example.com``.
    :param callable check_origin: origin check function. Receives the request origin string
        (scheme, domain and port) and returns ``True/False``. When ``check_origin`` is set,
        ``allowed_origins`` is ignored.
    :param bool auto_open_webbrowser: whether to open the web browser automatically once the
        server has started (if the operating system allows it).
    :param dict websocket_settings: arguments used when creating the aiohttp WebSocketResponse.
        See https://docs.aiohttp.org/en/stable/web_reference.html#websocketresponse
    :param aiohttp_settings: extra keyword arguments passed to the aiohttp Application.
        See https://docs.aiohttp.org/en/stable/web_reference.html#application
    """
    if not host:
        host = '0.0.0.0'

    if port == 0:
        port = get_free_port()

    handler = webio_handler(applications, allowed_origins=allowed_origins, check_origin=check_origin,
                            websocket_settings=websocket_settings)

    app = web.Application(**aiohttp_settings)
    # The WebSocket endpoint lives at /io; everything else is static content.
    app.add_routes([web.get('/io', handler)])
    app.add_routes(static_routes(STATIC_PATH))

    if auto_open_webbrowser:
        asyncio.get_event_loop().create_task(open_webbrowser_on_server_started('localhost', port))

    if debug:
        logging.getLogger("asyncio").setLevel(logging.DEBUG)

    print('Listen on %s:%s' % (host, port))
    web.run_app(app, host=host, port=port)
|