Forráskód Böngészése

add: origin check in tornado backend

wangweimin 5 éve
szülő
commit
05b401b383
1 módosított fájl, 43 hozzáadás és 4 törlés
  1. 43 4
      pywebio/platform/tornado.py

+ 43 - 4
pywebio/platform/tornado.py

@@ -1,14 +1,18 @@
 import asyncio
+import fnmatch
 import json
 import logging
 import threading
 import webbrowser
+from functools import partial
+from urllib.parse import urlparse
 
 import tornado
 import tornado.httpserver
 import tornado.ioloop
 import tornado.websocket
 from tornado.web import StaticFileHandler
+from tornado.websocket import WebSocketHandler
 from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
     set_session_implement, get_session_implement_for_target, SCRIPT_MODE
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
@@ -16,16 +20,45 @@ from ..utils import get_free_port, wait_host_port, STATIC_PATH
 logger = logging.getLogger(__name__)
 
 
-def webio_handler(target, session_type=None):
+def _check_origin(origin, allowed_origins, handler: WebSocketHandler):
+    if _is_same_site(origin, handler):
+        return True
+
+    return any(
+        fnmatch.fnmatch(origin, patten)
+        for patten in allowed_origins
+    )
+
+
+def _is_same_site(origin, handler: WebSocketHandler):
+    parsed_origin = urlparse(origin)
+    origin = parsed_origin.netloc
+    origin = origin.lower()
+
+    host = handler.request.headers.get("Host")
+
+    # Check to see that origin matches host directly, including ports
+    return origin == host
+
+
+def webio_handler(target, session_type=None, allowed_origins=None, check_origin=None):
     if not session_type:
         session_type = get_session_implement_for_target(target)
 
     set_session_implement(session_type)
 
-    class WSHandler(tornado.websocket.WebSocketHandler):
+    if check_origin is None:
+        check_origin_func = _is_same_site
+        if allowed_origins:
+            check_origin_func = partial(_check_origin, allowed_origins=allowed_origins)
+    else:
+        check_origin_func = lambda origin, handler: check_origin(origin)
+
+
+    class WSHandler(WebSocketHandler):
 
         def check_origin(self, origin):
-            return True
+            return check_origin_func(origin=origin, handler=self)
 
         def get_compression_options(self):
             # Non-None enables compression with default options.
@@ -90,6 +123,7 @@ def _setup_server(webio_handler, port=0, host='', **tornado_app_settings):
 
 
 def start_server(target, port=0, host='', debug=False,
+                 allowed_origins=None, check_origin=None,
                  auto_open_webbrowser=False,
                  session_type=None,
                  websocket_max_message_size=None,
@@ -105,6 +139,10 @@ def start_server(target, port=0, host='', debug=False,
         the server will listen on all IP addresses associated with the name.
         set empty string or to listen on all available interfaces.
     :param bool debug: Tornado debug mode
+    :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
+        来源包含协议和域名和端口部分,允许使用 ``*`` 作为通配符。 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
+    :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
+        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽律
     :param bool auto_open_webbrowser: Whether or not auto open web browser when server is started.
     :param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
         接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
@@ -127,7 +165,7 @@ def start_server(target, port=0, host='', debug=False,
         if kwargs[opt] is not None:
             tornado_app_settings[opt] = kwargs[opt]
 
-    handler = webio_handler(target, session_type=session_type)
+    handler = webio_handler(target, session_type=session_type, allowed_origins=allowed_origins, check_origin=check_origin)
     _, port = _setup_server(webio_handler=handler, port=port, host=host, **tornado_app_settings)
     if auto_open_webbrowser:
         tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, host or 'localhost', port)
@@ -159,6 +197,7 @@ def start_server_in_current_thread_session():
                 logger.debug('ScriptModeSession closed')
 
     async def wait_to_stop_loop():
+        """当只剩当前线程和Daemon线程运行时,关闭Server"""
         alive_none_daemonic_thread_cnt = None
         while alive_none_daemonic_thread_cnt != 1:
             alive_none_daemonic_thread_cnt = sum(