Răsfoiți Sursa

auto detect session type instead of using `session_type` parameter

wangweimin 5 ani în urmă
părinte
comite
ae08a45f3d

+ 1 - 0
docs/guide.rst

@@ -329,6 +329,7 @@ PyWebIO默认通过当前页面的同级的 ``./io`` API与后端进行通讯,
    `webio_view() <pywebio.platform.flask.webio_view>` 中使用 ``allowed_origins`` 或 ``check_origin``
    参数来允许后端接收页面所在的host
 
+.. _coroutine_based_session:
 
 基于协程的会话
 ^^^^^^^^^^^^^^

+ 0 - 1
pywebio/__init__.py

@@ -13,7 +13,6 @@ from . import input
 from . import output
 from .session import (
     run_async, run_asyncio_coroutine, register_thread,
-    THREAD_BASED, COROUTINE_BASED
 )
 from .exceptions import SessionException, SessionClosedException, SessionNotFoundException
 from .utils import STATIC_PATH

+ 2 - 3
pywebio/demos/zh/overview.py

@@ -6,7 +6,7 @@ import asyncio
 from datetime import datetime
 from functools import partial
 
-from pywebio import start_server, run_async, COROUTINE_BASED
+from pywebio import start_server, run_async
 from pywebio.input import *
 from pywebio.output import *
 
@@ -388,5 +388,4 @@ if __name__ == '__main__':
     parser.add_argument('--port', type=int, default=0, help='server bind port')
     args = parser.parse_args()
 
-    # from pywebio.platform.flask import start_server
-    start_server(feature_overview, debug=1, host=args.host, port=args.port, allowed_origins=['http://localhost:63342'])
+    start_server(feature_overview, debug=True, auto_open_webbrowser=True, host=args.host, port=args.port, allowed_origins=['http://localhost:63342'])

+ 14 - 11
pywebio/platform/flask.py

@@ -29,7 +29,7 @@ from typing import Dict
 from flask import Flask, request, jsonify, send_from_directory, Response
 
 from ..session import CoroutineBasedSession, get_session_implement, AbstractSession, \
-    set_session_implement, get_session_implement_for_target
+    set_session_implement_for_target
 from ..utils import STATIC_PATH
 from ..utils import random_str, LRUDict
 
@@ -139,13 +139,20 @@ def _webio_view(target, session_expire_seconds, check_origin):
     return response
 
 
-def webio_view(target, session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS, session_type=None, allowed_origins=None, check_origin=None):
-    """获取Flask view"""
+def webio_view(target, session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS, allowed_origins=None, check_origin=None):
+    """获取用于与Flask进行整合的view函数
 
-    if not session_type:
-        session_type = get_session_implement_for_target(target)
+    :param target: 任务函数。任务函数为协程函数时,使用 :ref:`基于协程的会话实现 <coroutine_based_session>` ;任务函数为普通函数时,使用基于线程的会话实现。
+    :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
+    :param session_expire_seconds: 会话不活跃过期时间。
+    :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
+        来源包含协议和域名和端口部分,允许使用 ``*`` 作为通配符。 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
+    :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
+        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
+    :return: Flask视图函数
+    """
 
-    set_session_implement(session_type)
+    set_session_implement_for_target(target)
 
     if check_origin is None:
         check_origin = lambda origin: any(
@@ -170,7 +177,6 @@ def _setup_event_loop():
 
 def start_server(target, port=8080, host='localhost',
                  allowed_origins=None, check_origin=None,
-                 session_type=None,
                  disable_asyncio=False,
                  session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS,
                  debug=False, **flask_options):
@@ -182,9 +188,7 @@ def start_server(target, port=8080, host='localhost',
     :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
         来源包含协议和域名和端口部分,允许使用 ``*`` 作为通配符。 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
     :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
-        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽律
-    :param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
-        接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
+        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
     :param disable_asyncio: 禁用 asyncio 函数。仅在当 ``session_type=COROUTINE_BASED`` 时有效。
         在Flask backend中使用asyncio需要单独开启一个线程来运行事件循环,
         若程序中没有使用到asyncio中的异步函数,可以开启此选项来避免不必要的资源浪费
@@ -198,7 +202,6 @@ def start_server(target, port=8080, host='localhost',
     app = Flask(__name__)
     app.route('/io', methods=['GET', 'POST', 'OPTIONS'])(
         webio_view(target, session_expire_seconds,
-                   session_type=session_type,
                    allowed_origins=allowed_origins,
                    check_origin=check_origin)
     )

+ 34 - 22
pywebio/platform/tornado.py

@@ -14,7 +14,7 @@ import tornado.websocket
 from tornado.web import StaticFileHandler
 from tornado.websocket import WebSocketHandler
 from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
-    set_session_implement, get_session_implement_for_target, SCRIPT_MODE
+    set_session_implement_for_target
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
 
 logger = logging.getLogger(__name__)
@@ -41,20 +41,13 @@ def _is_same_site(origin, handler: WebSocketHandler):
     return origin == host
 
 
-def webio_handler(target, session_type=None, allowed_origins=None, check_origin=None):
-    if not session_type:
-        session_type = get_session_implement_for_target(target)
-
-    set_session_implement(session_type)
-
-    if check_origin is None:
-        check_origin_func = _is_same_site
-        if allowed_origins:
-            check_origin_func = partial(_check_origin, allowed_origins=allowed_origins)
-    else:
-        check_origin_func = lambda origin, handler: check_origin(origin)
-
+def _webio_handler(target, check_origin_func=_is_same_site):
+    """获取用于Tornado进行整合的RequestHandle类
 
+    :param target: 任务函数
+    :param callable check_origin_func: check_origin_func(origin, handler) -> bool
+    :return: Tornado RequestHandle类
+    """
     class WSHandler(WebSocketHandler):
 
         def check_origin(self, origin):
@@ -97,6 +90,28 @@ def webio_handler(target, session_type=None, allowed_origins=None, check_origin=
 
     return WSHandler
 
+def webio_handler(target, allowed_origins=None, check_origin=None):
+    """获取用于Tornado进行整合的RequestHandle类
+
+    :param target: 任务函数。任务函数为协程函数时,使用 :ref:`基于协程的会话实现 <coroutine_based_session>` ;任务函数为普通函数时,使用基于线程的会话实现。
+    :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
+        来源包含协议和域名和端口部分,允许使用 ``*`` 作为通配符。 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
+    :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
+        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
+    :return: Tornado RequestHandle类
+    """
+    set_session_implement_for_target(target)
+
+    if check_origin is None:
+        check_origin_func = _is_same_site
+        if allowed_origins:
+            check_origin_func = partial(_check_origin, allowed_origins=allowed_origins)
+    else:
+        check_origin_func = lambda origin, handler: check_origin(origin)
+
+    return _webio_handler(target=target, check_origin_func=check_origin_func)
+
+
 
 async def open_webbrowser_on_server_started(host, port):
     url = 'http://%s:%s' % (host, port)
@@ -125,15 +140,14 @@ def _setup_server(webio_handler, port=0, host='', **tornado_app_settings):
 def start_server(target, port=0, host='', debug=False,
                  allowed_origins=None, check_origin=None,
                  auto_open_webbrowser=False,
-                 session_type=None,
                  websocket_max_message_size=None,
                  websocket_ping_interval=None,
                  websocket_ping_timeout=None,
                  **tornado_app_settings):
     """Start a Tornado server to serve `target` function
 
-    :param target: task function. It's a coroutine function is use CoroutineBasedSession or
-        a simple function is use ThreadBasedSession.
+    :param target: 任务函数。任务函数为协程函数时,使用 :ref:`基于协程的会话实现 <coroutine_based_session>` ;任务函数为普通函数时,使用基于线程的会话实现。
+    :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
     :param port: server bind port. set ``0`` to find a free port number to use
     :param host: server bind host. ``host`` may be either an IP address or hostname.  If it's a hostname,
         the server will listen on all IP addresses associated with the name.
@@ -142,10 +156,8 @@ def start_server(target, port=0, host='', debug=False,
     :param list allowed_origins: 除当前域名外,服务器还允许的请求的来源列表。
         来源包含协议和域名和端口部分,允许使用 ``*`` 作为通配符。 比如 ``https://*.example.com`` 、 ``*://*.example.com`` 、
     :param callable check_origin: 请求来源检查函数。接收请求来源(包含协议和域名和端口部分)字符串,
-        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽
+        返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽
     :param bool auto_open_webbrowser: Whether or not auto open web browser when server is started.
-    :param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
-        接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
     :param int websocket_max_message_size: Max bytes of a message which Tornado can accept.
         Messages larger than the ``websocket_max_message_size`` (default 10MiB) will not be accepted.
     :param int websocket_ping_interval: If set to a number, all websockets will be pinged every n seconds.
@@ -165,7 +177,7 @@ def start_server(target, port=0, host='', debug=False,
         if kwargs[opt] is not None:
             tornado_app_settings[opt] = kwargs[opt]
 
-    handler = webio_handler(target, session_type=session_type, allowed_origins=allowed_origins, check_origin=check_origin)
+    handler = webio_handler(target, allowed_origins=allowed_origins, check_origin=check_origin)
     _, port = _setup_server(webio_handler=handler, port=port, host=host, **tornado_app_settings)
     if auto_open_webbrowser:
         tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, host or 'localhost', port)
@@ -177,7 +189,7 @@ def start_server_in_current_thread_session():
     websocket_conn_opened = threading.Event()
     thread = threading.current_thread()
 
-    class SingleSessionWSHandler(webio_handler(None, session_type=SCRIPT_MODE)):
+    class SingleSessionWSHandler(_webio_handler(target=None)):
         session = None
 
         def open(self):

+ 7 - 16
pywebio/session/__init__.py

@@ -16,28 +16,19 @@ from .coroutinebased import CoroutineBasedSession
 from .threadbased import ThreadBasedSession, ScriptModeSession
 from ..exceptions import SessionNotFoundException
 
-THREAD_BASED = 'ThreadBased'
-COROUTINE_BASED = 'CoroutineBased'
-SCRIPT_MODE = 'ScriptMode'
 
 _session_type = ThreadBasedSession
 
-__all__ = ['run_async', 'run_asyncio_coroutine', 'register_thread', 'THREAD_BASED', 'COROUTINE_BASED']
+__all__ = ['run_async', 'run_asyncio_coroutine', 'register_thread']
 
 
-def get_session_implement_for_target(target_func):
-    """根据target_func函数类型获取默认会话实现"""
-    if asyncio.iscoroutinefunction(target_func) or inspect.isgeneratorfunction(target_func):
-        return COROUTINE_BASED
-    return THREAD_BASED
-
-
-def set_session_implement(session_type_name):
-    """设置会话实现类. 仅用于PyWebIO内部使用"""
+def set_session_implement_for_target(target_func):
+    """根据target_func函数类型设置会话实现"""
     global _session_type
-    sessions = {THREAD_BASED: ThreadBasedSession, COROUTINE_BASED: CoroutineBasedSession, SCRIPT_MODE: ScriptModeSession}
-    assert session_type_name in sessions, ValueError('No "%s" Session type ' % session_type_name)
-    _session_type = sessions[session_type_name]
+    if asyncio.iscoroutinefunction(target_func) or inspect.isgeneratorfunction(target_func):
+        _session_type = CoroutineBasedSession
+    else:
+        _session_type = ThreadBasedSession
 
 
 def get_session_implement():