|
@@ -1,14 +1,18 @@
|
|
|
import asyncio
|
|
|
+import fnmatch
|
|
|
import json
|
|
|
import logging
|
|
|
import threading
|
|
|
import webbrowser
|
|
|
+from functools import partial
|
|
|
+from urllib.parse import urlparse
|
|
|
|
|
|
import tornado
|
|
|
import tornado.httpserver
|
|
|
import tornado.ioloop
|
|
|
import tornado.websocket
|
|
|
from tornado.web import StaticFileHandler
|
|
|
+from tornado.websocket import WebSocketHandler
|
|
|
from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
|
|
|
set_session_implement, get_session_implement_for_target, SCRIPT_MODE
|
|
|
from ..utils import get_free_port, wait_host_port, STATIC_PATH
|
|
@@ -16,16 +20,45 @@ from ..utils import get_free_port, wait_host_port, STATIC_PATH
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
-def webio_handler(target, session_type=None):
|
|
|
+def _check_origin(origin, allowed_origins, handler: WebSocketHandler):
|
|
|
+ if _is_same_site(origin, handler):
|
|
|
+ return True
|
|
|
+
|
|
|
+ return any(
|
|
|
+ fnmatch.fnmatch(origin, patten)
|
|
|
+ for patten in allowed_origins
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def _is_same_site(origin, handler: WebSocketHandler):
|
|
|
+ parsed_origin = urlparse(origin)
|
|
|
+ origin = parsed_origin.netloc
|
|
|
+ origin = origin.lower()
|
|
|
+
|
|
|
+ host = handler.request.headers.get("Host")
|
|
|
+
|
|
|
+ # Check to see that origin matches host directly, including ports
|
|
|
+ return origin == host
|
|
|
+
|
|
|
+
|
|
|
+def webio_handler(target, session_type=None, allowed_origins=None, check_origin=None):
|
|
|
if not session_type:
|
|
|
session_type = get_session_implement_for_target(target)
|
|
|
|
|
|
set_session_implement(session_type)
|
|
|
|
|
|
- class WSHandler(tornado.websocket.WebSocketHandler):
|
|
|
+ if check_origin is None:
|
|
|
+ check_origin_func = _is_same_site
|
|
|
+ if allowed_origins:
|
|
|
+ check_origin_func = partial(_check_origin, allowed_origins=allowed_origins)
|
|
|
+ else:
|
|
|
+ check_origin_func = lambda origin, handler: check_origin(origin)
|
|
|
+
|
|
|
+
|
|
|
+ class WSHandler(WebSocketHandler):
|
|
|
|
|
|
def check_origin(self, origin):
|
|
|
- return True
|
|
|
+ return check_origin_func(origin=origin, handler=self)
|
|
|
|
|
|
def get_compression_options(self):
|
|
|
# Non-None enables compression with default options.
|
|
@@ -90,6 +123,7 @@ def _setup_server(webio_handler, port=0, host='', **tornado_app_settings):
|
|
|
|
|
|
|
|
|
def start_server(target, port=0, host='', debug=False,
|
|
|
+ allowed_origins=None, check_origin=None,
|
|
|
auto_open_webbrowser=False,
|
|
|
session_type=None,
|
|
|
websocket_max_message_size=None,
|
|
@@ -105,6 +139,10 @@ def start_server(target, port=0, host='', debug=False,
|
|
|
the server will listen on all IP addresses associated with the name.
|
|
|
set empty string or to listen on all available interfaces.
|
|
|
:param bool debug: Tornado debug mode
|
|
|
+ :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
|
|
|
+ 来源包含协议和域名和端口部分,允许使用 ``*`` 作为通配符。 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
|
|
|
+ :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
|
|
|
+ 返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽律
|
|
|
:param bool auto_open_webbrowser: Whether or not auto open web browser when server is started.
|
|
|
:param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
|
|
|
接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
|
|
@@ -127,7 +165,7 @@ def start_server(target, port=0, host='', debug=False,
|
|
|
if kwargs[opt] is not None:
|
|
|
tornado_app_settings[opt] = kwargs[opt]
|
|
|
|
|
|
- handler = webio_handler(target, session_type=session_type)
|
|
|
+ handler = webio_handler(target, session_type=session_type, allowed_origins=allowed_origins, check_origin=check_origin)
|
|
|
_, port = _setup_server(webio_handler=handler, port=port, host=host, **tornado_app_settings)
|
|
|
if auto_open_webbrowser:
|
|
|
tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, host or 'localhost', port)
|
|
@@ -159,6 +197,7 @@ def start_server_in_current_thread_session():
|
|
|
logger.debug('ScriptModeSession closed')
|
|
|
|
|
|
async def wait_to_stop_loop():
|
|
|
+ """当只剩当前线程和Daemon线程运行时,关闭Server"""
|
|
|
alive_none_daemonic_thread_cnt = None
|
|
|
while alive_none_daemonic_thread_cnt != 1:
|
|
|
alive_none_daemonic_thread_cnt = sum(
|