Bläddra i källkod

support multiple session type in one app

wangweimin 5 år sedan
förälder
incheckning
2ea0b09204

+ 7 - 6
pywebio/platform/flask.py

@@ -29,7 +29,7 @@ from typing import Dict
 from flask import Flask, request, jsonify, send_from_directory, Response
 from flask import Flask, request, jsonify, send_from_directory, Response
 
 
 from ..session import CoroutineBasedSession, get_session_implement, AbstractSession, \
 from ..session import CoroutineBasedSession, get_session_implement, AbstractSession, \
-    set_session_implement_for_target
+    register_session_implement_for_target
 from ..utils import STATIC_PATH
 from ..utils import STATIC_PATH
 from ..utils import random_str, LRUDict
 from ..utils import random_str, LRUDict
 
 
@@ -80,7 +80,7 @@ def cors_headers(origin, check_origin, headers=None):
     return headers
     return headers
 
 
 
 
-def _webio_view(target, session_expire_seconds, check_origin):
+def _webio_view(target, session_cls, session_expire_seconds, check_origin):
     """
     """
     :param target:
     :param target:
     :param session_expire_seconds:
     :param session_expire_seconds:
@@ -103,11 +103,12 @@ def _webio_view(target, session_expire_seconds, check_origin):
         return Response('ok', headers=headers)
         return Response('ok', headers=headers)
 
 
     webio_session_id = None
     webio_session_id = None
+
+    # webio-session-id 的请求头为空时,创建新 Session
     if 'webio-session-id' not in request.headers or not request.headers['webio-session-id']:  # start new WebIOSession
     if 'webio-session-id' not in request.headers or not request.headers['webio-session-id']:  # start new WebIOSession
         webio_session_id = random_str(24)
         webio_session_id = random_str(24)
         headers['webio-session-id'] = webio_session_id
         headers['webio-session-id'] = webio_session_id
-        Session = get_session_implement()
-        webio_session = Session(target)
+        webio_session = session_cls(target)
         _webio_sessions[webio_session_id] = webio_session
         _webio_sessions[webio_session_id] = webio_session
         _webio_expire[webio_session_id] = time.time()
         _webio_expire[webio_session_id] = time.time()
     elif request.headers['webio-session-id'] not in _webio_sessions:  # WebIOSession deleted
     elif request.headers['webio-session-id'] not in _webio_sessions:  # WebIOSession deleted
@@ -152,7 +153,7 @@ def webio_view(target, session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS, al
     :return: Flask视图函数
     :return: Flask视图函数
     """
     """
 
 
-    set_session_implement_for_target(target)
+    session_cls = register_session_implement_for_target(target)
 
 
     if check_origin is None:
     if check_origin is None:
         check_origin = lambda origin: any(
         check_origin = lambda origin: any(
@@ -160,7 +161,7 @@ def webio_view(target, session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS, al
             for patten in allowed_origins
             for patten in allowed_origins
         )
         )
 
 
-    view_func = partial(_webio_view, target=target,
+    view_func = partial(_webio_view, target=target, session_cls=session_cls,
                         session_expire_seconds=session_expire_seconds,
                         session_expire_seconds=session_expire_seconds,
                         check_origin=check_origin)
                         check_origin=check_origin)
     view_func.__name__ = 'webio_view'
     view_func.__name__ = 'webio_view'

+ 12 - 9
pywebio/platform/tornado.py

@@ -13,8 +13,8 @@ import tornado.ioloop
 import tornado.websocket
 import tornado.websocket
 from tornado.web import StaticFileHandler
 from tornado.web import StaticFileHandler
 from tornado.websocket import WebSocketHandler
 from tornado.websocket import WebSocketHandler
-from ..session import CoroutineBasedSession, ThreadBasedSession, get_session_implement, ScriptModeSession, \
-    set_session_implement_for_target, AbstractSession
+from ..session import CoroutineBasedSession, ThreadBasedSession, ScriptModeSession, \
+    register_session_implement_for_target, AbstractSession
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -41,13 +41,14 @@ def _is_same_site(origin, handler: WebSocketHandler):
     return origin == host
     return origin == host
 
 
 
 
-def _webio_handler(target, check_origin_func=_is_same_site):
+def _webio_handler(target, session_cls, check_origin_func=_is_same_site):
     """获取用于Tornado进行整合的RequestHandle类
     """获取用于Tornado进行整合的RequestHandle类
 
 
     :param target: 任务函数
     :param target: 任务函数
     :param callable check_origin_func: check_origin_func(origin, handler) -> bool
     :param callable check_origin_func: check_origin_func(origin, handler) -> bool
     :return: Tornado RequestHandle类
     :return: Tornado RequestHandle类
     """
     """
+
     class WSHandler(WebSocketHandler):
     class WSHandler(WebSocketHandler):
 
 
         def check_origin(self, origin):
         def check_origin(self, origin):
@@ -67,13 +68,15 @@ def _webio_handler(target, check_origin_func=_is_same_site):
 
 
             self._close_from_session_tag = False  # 由session主动关闭连接
             self._close_from_session_tag = False  # 由session主动关闭连接
 
 
-            if get_session_implement() is CoroutineBasedSession:
+            if session_cls is CoroutineBasedSession:
                 self.session = CoroutineBasedSession(target, on_task_command=self.send_msg_to_client,
                 self.session = CoroutineBasedSession(target, on_task_command=self.send_msg_to_client,
                                                      on_session_close=self.close_from_session)
                                                      on_session_close=self.close_from_session)
-            else:
+            elif session_cls is ThreadBasedSession:
                 self.session = ThreadBasedSession(target, on_task_command=self.send_msg_to_client,
                 self.session = ThreadBasedSession(target, on_task_command=self.send_msg_to_client,
                                                   on_session_close=self.close_from_session,
                                                   on_session_close=self.close_from_session,
                                                   loop=asyncio.get_event_loop())
                                                   loop=asyncio.get_event_loop())
+            else:
+                raise RuntimeError("Don't support session type:%s" % session_cls)
 
 
         def on_message(self, message):
         def on_message(self, message):
             data = json.loads(message)
             data = json.loads(message)
@@ -90,6 +93,7 @@ def _webio_handler(target, check_origin_func=_is_same_site):
 
 
     return WSHandler
     return WSHandler
 
 
+
 def webio_handler(target, allowed_origins=None, check_origin=None):
 def webio_handler(target, allowed_origins=None, check_origin=None):
     """获取用于Tornado进行整合的RequestHandle类
     """获取用于Tornado进行整合的RequestHandle类
 
 
@@ -100,7 +104,7 @@ def webio_handler(target, allowed_origins=None, check_origin=None):
         返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
         返回 ``True/False`` 。若设置了 ``check_origin`` , ``allowed_origins`` 参数将被忽略
     :return: Tornado RequestHandle类
     :return: Tornado RequestHandle类
     """
     """
-    set_session_implement_for_target(target)
+    session_cls = register_session_implement_for_target(target)
 
 
     if check_origin is None:
     if check_origin is None:
         check_origin_func = _is_same_site
         check_origin_func = _is_same_site
@@ -109,8 +113,7 @@ def webio_handler(target, allowed_origins=None, check_origin=None):
     else:
     else:
         check_origin_func = lambda origin, handler: check_origin(origin)
         check_origin_func = lambda origin, handler: check_origin(origin)
 
 
-    return _webio_handler(target=target, check_origin_func=check_origin_func)
-
+    return _webio_handler(target=target, session_cls=session_cls, check_origin_func=check_origin_func)
 
 
 
 
 async def open_webbrowser_on_server_started(host, port):
 async def open_webbrowser_on_server_started(host, port):
@@ -188,7 +191,7 @@ def start_server_in_current_thread_session():
     websocket_conn_opened = threading.Event()
     websocket_conn_opened = threading.Event()
     thread = threading.current_thread()
     thread = threading.current_thread()
 
 
-    class SingleSessionWSHandler(_webio_handler(target=None)):
+    class SingleSessionWSHandler(_webio_handler(target=None, session_cls=None)):
         session = None
         session = None
 
 
         def open(self):
         def open(self):

+ 35 - 16
pywebio/session/__init__.py

@@ -16,26 +16,44 @@ from .coroutinebased import CoroutineBasedSession
 from .threadbased import ThreadBasedSession, ScriptModeSession
 from .threadbased import ThreadBasedSession, ScriptModeSession
 from ..exceptions import SessionNotFoundException
 from ..exceptions import SessionNotFoundException
 
 
-_session_type = None
+# 当前进程中正在使用的会话实现的列表
+_active_session_cls = []
 
 
 __all__ = ['run_async', 'run_asyncio_coroutine', 'register_thread']
 __all__ = ['run_async', 'run_asyncio_coroutine', 'register_thread']
 
 
 
 
-def set_session_implement_for_target(target_func):
-    """根据target_func函数类型设置会话实现"""
-    global _session_type
+def register_session_implement_for_target(target_func):
+    """根据target_func函数类型注册会话实现,并返回会话实现"""
     if asyncio.iscoroutinefunction(target_func) or inspect.isgeneratorfunction(target_func):
     if asyncio.iscoroutinefunction(target_func) or inspect.isgeneratorfunction(target_func):
-        _session_type = CoroutineBasedSession
+        cls = CoroutineBasedSession
     else:
     else:
-        _session_type = ThreadBasedSession
+        cls = ThreadBasedSession
+
+    if cls not in _active_session_cls:
+        _active_session_cls.append(cls)
+
+    return cls
 
 
 
 
 def get_session_implement():
 def get_session_implement():
-    global _session_type
-    if _session_type is None:
-        _session_type = ScriptModeSession
+    """获取当前会话实现。仅供内部实现使用。应在会话上下文中调用"""
+    if not _active_session_cls:
+        _active_session_cls.append(ScriptModeSession)
         _start_script_mode_server()
         _start_script_mode_server()
-    return _session_type
+
+    # 当前正在使用的会话实现只有一个
+    if len(_active_session_cls) == 1:
+        return _active_session_cls[0]
+
+    # 当前有多个正在使用的会话实现
+    for cls in _active_session_cls:
+        try:
+            cls.get_current_session()
+            return cls
+        except SessionNotFoundException:
+            pass
+
+    raise SessionNotFoundException
 
 
 
 
 def _start_script_mode_server():
 def _start_script_mode_server():
@@ -55,16 +73,17 @@ def check_session_impl(session_type):
     def decorator(func):
     def decorator(func):
         @wraps(func)
         @wraps(func)
         def inner(*args, **kwargs):
         def inner(*args, **kwargs):
-            now_impl = get_session_implement()
-            if not issubclass(now_impl,
-                              session_type):  # Check if 'now_impl' is a derived from session_type or is the same class
+            curr_impl = get_session_implement()
+
+            # Check if 'now_impl' is a derived from session_type or is the same class
+            if not issubclass(curr_impl, session_type):
                 func_name = getattr(func, '__name__', str(func))
                 func_name = getattr(func, '__name__', str(func))
                 require = getattr(session_type, '__name__', str(session_type))
                 require = getattr(session_type, '__name__', str(session_type))
-                now = getattr(now_impl, '__name__', str(now_impl))
+                curr = getattr(curr_impl, '__name__', str(curr_impl))
 
 
                 raise RuntimeError("Only can invoke `{func_name:s}` in {require:s} context."
                 raise RuntimeError("Only can invoke `{func_name:s}` in {require:s} context."
-                                   " You are now in {now:s} context".format(func_name=func_name, require=require,
-                                                                            now=now))
+                                   " You are now in {curr:s} context".format(func_name=func_name, require=require,
+                                                                             curr=curr))
             return func(*args, **kwargs)
             return func(*args, **kwargs)
 
 
         return inner
         return inner

+ 16 - 6
pywebio/session/coroutinebased.py

@@ -2,6 +2,7 @@ import asyncio
 import inspect
 import inspect
 import logging
 import logging
 import sys
 import sys
+import threading
 import traceback
 import traceback
 from contextlib import contextmanager
 from contextlib import contextmanager
 
 
@@ -44,8 +45,9 @@ class CoroutineBasedSession(AbstractSession):
 
 
     @staticmethod
     @staticmethod
     def get_current_session() -> "CoroutineBasedSession":
     def get_current_session() -> "CoroutineBasedSession":
-        if _context.current_session is None:
-            raise SessionNotFoundException("No current found in context!")
+        if _context.current_session is None or \
+                _context.current_session.session_thread_id != threading.current_thread().ident:
+            raise SessionNotFoundException("No session found in current context!")
         return _context.current_session
         return _context.current_session
 
 
     @staticmethod
     @staticmethod
@@ -67,12 +69,20 @@ class CoroutineBasedSession(AbstractSession):
 
 
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
         self._on_session_close = on_session_close or (lambda: None)
+
+        # 当前会话未被Backend处理的消息
         self.unhandled_task_msgs = []
         self.unhandled_task_msgs = []
 
 
+        # 创建会话的线程id。当前会话只能在本线程中使用
+        self.session_thread_id = threading.current_thread().ident
+
+        # 会话内的协程任务
         self.coros = {}  # coro_task_id -> Task()
         self.coros = {}  # coro_task_id -> Task()
 
 
         self._closed = False
         self._closed = False
-        self._not_closed_coro_cnt = 1  # 当前会话未结束运行的协程数量。当 self._not_closed_coro_cnt == 0 时,会话结束。
+
+        # 当前会话未结束运行(已创建和正在运行的)的协程数量。当 _alive_coro_cnt 变为 0 时,会话结束。
+        self._alive_coro_cnt = 1
 
 
         main_task = Task(target(), session=self, on_coro_stop=self._on_task_finish)
         main_task = Task(target(), session=self, on_coro_stop=self._on_task_finish)
         self.coros[main_task.coro_id] = main_task
         self.coros[main_task.coro_id] = main_task
@@ -83,13 +93,13 @@ class CoroutineBasedSession(AbstractSession):
         task.step(result)
         task.step(result)
 
 
     def _on_task_finish(self, task: "Task"):
     def _on_task_finish(self, task: "Task"):
-        self._not_closed_coro_cnt -= 1
+        self._alive_coro_cnt -= 1
 
 
         if task.coro_id in self.coros:
         if task.coro_id in self.coros:
             logger.debug('del self.coros[%s]', task.coro_id)
             logger.debug('del self.coros[%s]', task.coro_id)
             del self.coros[task.coro_id]
             del self.coros[task.coro_id]
 
 
-        if self._not_closed_coro_cnt <= 0 and not self.closed():
+        if self._alive_coro_cnt <= 0 and not self.closed():
             self.send_task_command(dict(command='close_session'))
             self.send_task_command(dict(command='close_session'))
             self._on_session_close()
             self._on_session_close()
             self.close()
             self.close()
@@ -197,7 +207,7 @@ class CoroutineBasedSession(AbstractSession):
         :param coro_obj: 协程对象
         :param coro_obj: 协程对象
         :return: An instance of  `TaskHandle` is returned, which can be used later to close the task.
         :return: An instance of  `TaskHandle` is returned, which can be used later to close the task.
         """
         """
-        self._not_closed_coro_cnt += 1
+        self._alive_coro_cnt += 1
 
 
         task = Task(coro_obj, session=self, on_coro_stop=self._on_task_finish)
         task = Task(coro_obj, session=self, on_coro_stop=self._on_task_finish)
         self.coros[task.coro_id] = task
         self.coros[task.coro_id] = task