1
0
Эх сурвалжийг харах

auto select session implement when `start_server`

wangweimin 5 жил өмнө
parent
commit
cd0e8f5c41

+ 1 - 2
pywebio/demos/zh/overview.py

@@ -388,5 +388,4 @@ if __name__ == '__main__':
     parser.add_argument('--port', type=int, default=0, help='server bind port')
     args = parser.parse_args()
 
-    start_server(feature_overview, host=args.host, port=args.port, auto_open_webbrowser=True,
-                 session_type=COROUTINE_BASED)
+    start_server(feature_overview, host=args.host, port=args.port, auto_open_webbrowser=True)

+ 6 - 3
pywebio/platform/flask.py

@@ -28,7 +28,7 @@ from typing import Dict
 from flask import Flask, request, jsonify, send_from_directory
 
 from ..session import CoroutineBasedSession, get_session_implement, AbstractSession, \
-    mark_server_started
+    mark_server_started, get_session_implement_for_target
 from ..utils import STATIC_PATH
 from ..utils import random_str, LRUDict
 
@@ -95,7 +95,7 @@ def _webio_view(coro_func, session_expire_seconds):
 
     if request.method == 'POST':  # client push event
         webio_session.send_client_event(request.json)
-        time.sleep(WAIT_MS_ON_POST/1000.0)
+        time.sleep(WAIT_MS_ON_POST / 1000.0)
 
     elif request.method == 'GET':  # client pull messages
         pass
@@ -137,7 +137,7 @@ def start_server(target, port=8080, host='localhost',
         a simple function is use ThreadBasedSession.
     :param port: server bind port. set ``0`` to find a free port number to use
     :param host: server bind host. ``host`` may be either an IP address or hostname.  If it's a hostname,
-    :param str session_type: Session <pywebio.session.AbstractSession>` 的实现,默认为基于线程的会话实现。
+    :param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
         接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
     :param disable_asyncio: 禁用 asyncio 函数。仅在当 ``session_type=COROUTINE_BASED`` 时有效。
         在Flask backend中使用asyncio需要单独开启一个线程来运行事件循环,
@@ -148,6 +148,9 @@ def start_server(target, port=8080, host='localhost',
         ref: https://www.tornadoweb.org/en/stable/web.html#tornado.web.Application.settings
     :return:
     """
+    if not session_type:
+        session_type = get_session_implement_for_target(target)
+
     mark_server_started(session_type)
 
     app = Flask(__name__)

+ 7 - 4
pywebio/platform/tornado.py

@@ -10,7 +10,7 @@ import tornado.ioloop
 import tornado.websocket
 from tornado.web import StaticFileHandler
 from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
-    mark_server_started
+    mark_server_started, get_session_implement_for_target
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
 
 logger = logging.getLogger(__name__)
@@ -101,7 +101,7 @@ def start_server(target, port=0, host='', debug=False,
         set empty string or to listen on all available interfaces.
     :param bool debug: Tornado debug mode
     :param bool auto_open_webbrowser: Whether or not auto open web browser when server is started.
-    :param str session_type: `Session <pywebio.session.AbstractSession>` 的实现,默认为基于线程的会话实现。
+    :param str session_type: 指定 `Session <pywebio.session.AbstractSession>` 的实现。未设置则根据 ``target`` 类型选择合适的实现。
         接受的值为 `pywebio.session.THREAD_BASED` 和 `pywebio.session.COROUTINE_BASED`
     :param int websocket_max_message_size: Max bytes of a message which Tornado can accept.
         Messages larger than the ``websocket_max_message_size`` (default 10MiB) will not be accepted.
@@ -117,6 +117,9 @@ def start_server(target, port=0, host='', debug=False,
     """
     kwargs = locals()
 
+    if not session_type:
+        session_type = get_session_implement_for_target(target)
+
     mark_server_started(session_type)
 
     app_options = ['debug', 'websocket_max_message_size', 'websocket_ping_interval', 'websocket_ping_timeout']
@@ -144,8 +147,8 @@ def start_server_in_current_thread_session():
         def open(self):
             if SingleSessionWSHandler.session is None:
                 SingleSessionWSHandler.session = ScriptModeSession(thread,
-                                                                         on_task_command=self.send_msg_to_client,
-                                                                         loop=asyncio.get_event_loop())
+                                                                   on_task_command=self.send_msg_to_client,
+                                                                   loop=asyncio.get_event_loop())
                 websocket_conn_opened.set()
             else:
                 self.close()

+ 8 - 1
pywebio/session/__init__.py

@@ -1,4 +1,4 @@
-import threading
+import threading, asyncio, inspect
 from functools import wraps
 
 from .base import AbstractSession
@@ -25,6 +25,13 @@ def mark_server_started(session_type_name=None):
         _set_session_implement(session_type_name)
 
 
+def get_session_implement_for_target(target_func):
+    """根据target_func函数类型获取默认会话实现"""
+    if asyncio.iscoroutinefunction(target_func) or inspect.isgeneratorfunction(target_func):
+        return COROUTINE_BASED
+    return THREAD_BASED
+
+
 def _set_session_implement(session_type_name):
     """设置会话实现类. 仅用于PyWebIO内部使用"""
     global _session_type