Переглянути джерело

add `session_type` parameter to tornado's `webio_handler ` flask's `webio_view`

wangweimin 5 роки тому
батько
коміт
a7e45b5225
3 змінених файлів з 26 додано та 36 видалено
  1. 12 10
      pywebio/platform/flask.py
  2. 11 13
      pywebio/platform/tornado.py
  3. 3 13
      pywebio/session/__init__.py

+ 12 - 10
pywebio/platform/flask.py

@@ -28,7 +28,7 @@ from typing import Dict
 from flask import Flask, request, jsonify, send_from_directory
 from flask import Flask, request, jsonify, send_from_directory
 
 
 from ..session import CoroutineBasedSession, get_session_implement, AbstractSession, \
 from ..session import CoroutineBasedSession, get_session_implement, AbstractSession, \
-    mark_server_started, get_session_implement_for_target
+    set_session_implement, get_session_implement_for_target
 from ..utils import STATIC_PATH
 from ..utils import STATIC_PATH
 from ..utils import random_str, LRUDict
 from ..utils import random_str, LRUDict
 
 
@@ -65,7 +65,7 @@ def _remove_webio_session(sid):
     del _webio_expire[sid]
     del _webio_expire[sid]
 
 
 
 
-def _webio_view(coro_func, session_expire_seconds):
+def _webio_view(target, session_expire_seconds):
     """
     """
     :param coro_func:
     :param coro_func:
     :param session_expire_seconds:
     :param session_expire_seconds:
@@ -84,7 +84,7 @@ def _webio_view(coro_func, session_expire_seconds):
         set_header = True
         set_header = True
         webio_session_id = random_str(24)
         webio_session_id = random_str(24)
         Session = get_session_implement()
         Session = get_session_implement()
-        webio_session = Session(coro_func)
+        webio_session = Session(target)
         _webio_sessions[webio_session_id] = webio_session
         _webio_sessions[webio_session_id] = webio_session
         _webio_expire[webio_session_id] = time.time()
         _webio_expire[webio_session_id] = time.time()
     elif request.headers['webio-session-id'] not in _webio_sessions:  # WebIOSession deleted
     elif request.headers['webio-session-id'] not in _webio_sessions:  # WebIOSession deleted
@@ -112,9 +112,15 @@ def _webio_view(coro_func, session_expire_seconds):
     return response
     return response
 
 
 
 
-def webio_view(coro_func, session_expire_seconds):
+def webio_view(target, session_expire_seconds, session_type=None):
     """获取Flask view"""
     """获取Flask view"""
-    view_func = partial(_webio_view, coro_func=coro_func, session_expire_seconds=session_expire_seconds)
+
+    if not session_type:
+        session_type = get_session_implement_for_target(target)
+
+    set_session_implement(session_type)
+
+    view_func = partial(_webio_view, target=target, session_expire_seconds=session_expire_seconds)
     view_func.__name__ = 'webio_view'
     view_func.__name__ = 'webio_view'
     return view_func
     return view_func
 
 
@@ -148,13 +154,9 @@ def start_server(target, port=8080, host='localhost',
         ref: https://www.tornadoweb.org/en/stable/web.html#tornado.web.Application.settings
         ref: https://www.tornadoweb.org/en/stable/web.html#tornado.web.Application.settings
     :return:
     :return:
     """
     """
-    if not session_type:
-        session_type = get_session_implement_for_target(target)
-
-    mark_server_started(session_type)
 
 
     app = Flask(__name__)
     app = Flask(__name__)
-    app.route('/io', methods=['GET', 'POST'])(webio_view(target, session_expire_seconds))
+    app.route('/io', methods=['GET', 'POST'])(webio_view(target, session_expire_seconds, session_type=session_type))
 
 
     @app.route('/')
     @app.route('/')
     @app.route('/<path:static_file>')
     @app.route('/<path:static_file>')

+ 11 - 13
pywebio/platform/tornado.py

@@ -10,13 +10,18 @@ import tornado.ioloop
 import tornado.websocket
 import tornado.websocket
 from tornado.web import StaticFileHandler
 from tornado.web import StaticFileHandler
 from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
 from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
-    mark_server_started, get_session_implement_for_target
+    set_session_implement, get_session_implement_for_target, SCRIPT_MODE
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-def webio_handler(task_func):
+def webio_handler(target, session_type=None):
+    if not session_type:
+        session_type = get_session_implement_for_target(target)
+
+    set_session_implement(session_type)
+
     class WSHandler(tornado.websocket.WebSocketHandler):
     class WSHandler(tornado.websocket.WebSocketHandler):
 
 
         def check_origin(self, origin):
         def check_origin(self, origin):
@@ -37,10 +42,10 @@ def webio_handler(task_func):
             self._close_from_session_tag = False  # 由session主动关闭连接
             self._close_from_session_tag = False  # 由session主动关闭连接
 
 
             if get_session_implement() is CoroutineBasedSession:
             if get_session_implement() is CoroutineBasedSession:
-                self.session = CoroutineBasedSession(task_func, on_task_command=self.send_msg_to_client,
+                self.session = CoroutineBasedSession(target, on_task_command=self.send_msg_to_client,
                                                      on_session_close=self.close_from_session)
                                                      on_session_close=self.close_from_session)
             else:
             else:
-                self.session = ThreadBasedSession(task_func, on_task_command=self.send_msg_to_client,
+                self.session = ThreadBasedSession(target, on_task_command=self.send_msg_to_client,
                                                   on_session_close=self.close_from_session,
                                                   on_session_close=self.close_from_session,
                                                   loop=asyncio.get_event_loop())
                                                   loop=asyncio.get_event_loop())
 
 
@@ -117,17 +122,12 @@ def start_server(target, port=0, host='', debug=False,
     """
     """
     kwargs = locals()
     kwargs = locals()
 
 
-    if not session_type:
-        session_type = get_session_implement_for_target(target)
-
-    mark_server_started(session_type)
-
     app_options = ['debug', 'websocket_max_message_size', 'websocket_ping_interval', 'websocket_ping_timeout']
     app_options = ['debug', 'websocket_max_message_size', 'websocket_ping_interval', 'websocket_ping_timeout']
     for opt in app_options:
     for opt in app_options:
         if kwargs[opt] is not None:
         if kwargs[opt] is not None:
             tornado_app_settings[opt] = kwargs[opt]
             tornado_app_settings[opt] = kwargs[opt]
 
 
-    handler = webio_handler(target)
+    handler = webio_handler(target, session_type=session_type)
     _, port = _setup_server(webio_handler=handler, port=port, host=host, **tornado_app_settings)
     _, port = _setup_server(webio_handler=handler, port=port, host=host, **tornado_app_settings)
     if auto_open_webbrowser:
     if auto_open_webbrowser:
         tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, host or 'localhost', port)
         tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, host or 'localhost', port)
@@ -136,12 +136,10 @@ def start_server(target, port=0, host='', debug=False,
 
 
 def start_server_in_current_thread_session():
 def start_server_in_current_thread_session():
     """启动 script mode 的server"""
     """启动 script mode 的server"""
-    mark_server_started()
-
     websocket_conn_opened = threading.Event()
     websocket_conn_opened = threading.Event()
     thread = threading.current_thread()
     thread = threading.current_thread()
 
 
-    class SingleSessionWSHandler(webio_handler(None)):
+    class SingleSessionWSHandler(webio_handler(None, session_type=SCRIPT_MODE)):
         session = None
         session = None
 
 
         def open(self):
         def open(self):

+ 3 - 13
pywebio/session/__init__.py

@@ -10,22 +10,12 @@ from ..exceptions import SessionNotFoundException
 
 
 THREAD_BASED = 'ThreadBased'
 THREAD_BASED = 'ThreadBased'
 COROUTINE_BASED = 'CoroutineBased'
 COROUTINE_BASED = 'CoroutineBased'
+SCRIPT_MODE = 'ScriptMode'
 
 
 _session_type = ThreadBasedSession
 _session_type = ThreadBasedSession
 
 
 __all__ = ['run_async', 'run_asyncio_coroutine', 'register_thread', 'THREAD_BASED', 'COROUTINE_BASED']
 __all__ = ['run_async', 'run_asyncio_coroutine', 'register_thread', 'THREAD_BASED', 'COROUTINE_BASED']
 
 
-_server_started = False
-
-
-def mark_server_started(session_type_name=None):
-    """标记服务端已经启动. 仅用于PyWebIO内部使用"""
-    global _server_started
-    _server_started = True
-
-    if session_type_name is not None:
-        _set_session_implement(session_type_name)
-
 
 
 def get_session_implement_for_target(target_func):
 def get_session_implement_for_target(target_func):
     """根据target_func函数类型获取默认会话实现"""
     """根据target_func函数类型获取默认会话实现"""
@@ -34,10 +24,10 @@ def get_session_implement_for_target(target_func):
     return THREAD_BASED
     return THREAD_BASED
 
 
 
 
-def _set_session_implement(session_type_name):
+def set_session_implement(session_type_name):
     """设置会话实现类. 仅用于PyWebIO内部使用"""
     """设置会话实现类. 仅用于PyWebIO内部使用"""
     global _session_type
     global _session_type
-    sessions = {THREAD_BASED: ThreadBasedSession, COROUTINE_BASED: CoroutineBasedSession}
+    sessions = {THREAD_BASED: ThreadBasedSession, COROUTINE_BASED: CoroutineBasedSession, SCRIPT_MODE: ScriptModeSession}
     assert session_type_name in sessions, ValueError('No "%s" Session type ' % session_type_name)
     assert session_type_name in sessions, ValueError('No "%s" Session type ' % session_type_name)
     _session_type = sessions[session_type_name]
     _session_type = sessions[session_type_name]