Procházet zdrojové kódy

add: origin check in tornado backend

wangweimin před 5 roky
rodič
revize
05b401b383
1 změnil soubory, kde provedl 43 přidání a 4 odebrání
  1. 43 4
      pywebio/platform/tornado.py

+ 43 - 4
pywebio/platform/tornado.py

@@ -1,14 +1,18 @@
 import asyncio
 import asyncio
+import fnmatch
 import json
 import json
 import logging
 import logging
 import threading
 import threading
 import webbrowser
 import webbrowser
+from functools import partial
+from urllib.parse import urlparse
 
 
 import tornado
 import tornado
 import tornado.httpserver
 import tornado.httpserver
 import tornado.ioloop
 import tornado.ioloop
 import tornado.websocket
 import tornado.websocket
 from tornado.web import StaticFileHandler
 from tornado.web import StaticFileHandler
+from tornado.websocket import WebSocketHandler
 from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
 from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
     set_session_implement, get_session_implement_for_target, SCRIPT_MODE
     set_session_implement, get_session_implement_for_target, SCRIPT_MODE
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
@@ -16,16 +20,45 @@ from ..utils import get_free_port, wait_host_port, STATIC_PATH
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-def webio_handler(target, session_type=None):
+def _check_origin(origin, allowed_origins, handler: WebSocketHandler):
+    if _is_same_site(origin, handler):
+        return True
+
+    return any(
+        fnmatch.fnmatch(origin, patten)
+        for patten in allowed_origins
+    )
+
+
+def _is_same_site(origin, handler: WebSocketHandler):
+    parsed_origin = urlparse(origin)
+    origin = parsed_origin.netloc
+    origin = origin.lower()
+
+    host = handler.request.headers.get("Host")
+
+    # Check to see that origin matches host directly, including ports
+    return origin == host
+
+
+def webio_handler(target, session_type=None, allowed_origins=None, check_origin=None):
     if not session_type:
     if not session_type:
         session_type = get_session_implement_for_target(target)
         session_type = get_session_implement_for_target(target)
 
 
     set_session_implement(session_type)
     set_session_implement(session_type)
 
 
-    class WSHandler(tornado.websocket.WebSocketHandler):
+    if check_origin is None:
+        check_origin_func = _is_same_site
+        if allowed_origins:
+            check_origin_func = partial(_check_origin, allowed_origins=allowed_origins)
+    else:
+        check_origin_func = lambda origin, handler: check_origin(origin)
+
+
+    class WSHandler(WebSocketHandler):
 
 
         def check_origin(self, origin):
         def check_origin(self, origin):
-            return True
+            return check_origin_func(origin=origin, handler=self)
 
 
         def get_compression_options(self):
         def get_compression_options(self):
             # Non-None enables compression with default options.
             # Non-None enables compression with default options.
@@ -90,6 +123,7 @@ def _setup_server(webio_handler, port=0, host='', **tornado_app_settings):
 
 
 
 
 def start_server(target, port=0, host='', debug=False,
 def start_server(target, port=0, host='', debug=False,
+                 allowed_origins=None, check_origin=None,
                  auto_open_webbrowser=False,
                  auto_open_webbrowser=False,
                  session_type=None,
                  session_type=None,
                  websocket_max_message_size=None,
                  websocket_max_message_size=None,
@@ -105,6 +139,10 @@ def start_server(target, port=0, host='', debug=False,
         the server will listen on all IP addresses associated with the name.
         the server will listen on all IP addresses associated with the name.
         set empty string or to listen on all available interfaces.
         set empty string or to listen on all available interfaces.
     :param bool debug: Tornado debug mode
     :param bool debug: Tornado debug mode
+    :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
+        来源包含协议和域名和端口部分,允许使用 ``*`` 作为通配符。 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
+    :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
+        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽律
     :param bool auto_open_webbrowser: Whether or not auto open web browser when server is started.
     :param bool auto_open_webbrowser: Whether or not auto open web browser when server is started.
     :param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
     :param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
         接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
         接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
@@ -127,7 +165,7 @@ def start_server(target, port=0, host='', debug=False,
         if kwargs[opt] is not None:
         if kwargs[opt] is not None:
             tornado_app_settings[opt] = kwargs[opt]
             tornado_app_settings[opt] = kwargs[opt]
 
 
-    handler = webio_handler(target, session_type=session_type)
+    handler = webio_handler(target, session_type=session_type, allowed_origins=allowed_origins, check_origin=check_origin)
     _, port = _setup_server(webio_handler=handler, port=port, host=host, **tornado_app_settings)
     _, port = _setup_server(webio_handler=handler, port=port, host=host, **tornado_app_settings)
     if auto_open_webbrowser:
     if auto_open_webbrowser:
         tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, host or 'localhost', port)
         tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, host or 'localhost', port)
@@ -159,6 +197,7 @@ def start_server_in_current_thread_session():
                 logger.debug('ScriptModeSession closed')
                 logger.debug('ScriptModeSession closed')
 
 
     async def wait_to_stop_loop():
     async def wait_to_stop_loop():
+        """当只剩当前线程和Daemon线程运行时,关闭Server"""
         alive_none_daemonic_thread_cnt = None
         alive_none_daemonic_thread_cnt = None
         while alive_none_daemonic_thread_cnt != 1:
         while alive_none_daemonic_thread_cnt != 1:
             alive_none_daemonic_thread_cnt = sum(
             alive_none_daemonic_thread_cnt = sum(