Parcourir la source

maint: refine Session base class && add session.data()

wangweimin il y a 5 ans
Parent
commit
17dc583bb6

+ 2 - 2
pywebio/platform/aiohttp.py

@@ -9,7 +9,7 @@ from urllib.parse import urlparse
 from aiohttp import web
 
 from .tornado import open_webbrowser_on_server_started
-from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, AbstractSession
+from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
 from ..session.base import get_session_info_from_headers
 from ..utils import get_free_port, STATIC_PATH
 
@@ -56,7 +56,7 @@ def _webio_handler(target, session_cls, websocket_settings, check_origin_func=_i
 
         close_from_session_tag = False  # 是否由session主动关闭连接
 
-        def send_msg_to_client(session: AbstractSession):
+        def send_msg_to_client(session: Session):
             for msg in session.get_task_commands():
                 msg_str = json.dumps(msg)
                 ioloop.create_task(ws.send_str(msg_str))

+ 2 - 2
pywebio/platform/httpbased.py

@@ -20,7 +20,7 @@ import threading
 from typing import Dict
 
 import time
-from ..session import CoroutineBasedSession, AbstractSession, register_session_implement_for_target
+from ..session import CoroutineBasedSession, Session, register_session_implement_for_target
 from ..session.base import get_session_info_from_headers
 from ..utils import random_str, LRUDict
 
@@ -81,7 +81,7 @@ _event_loop = None
 
 # todo: use lock to avoid thread race condition
 class HttpHandler:
-    # type: Dict[str, AbstractSession]
+    # type: Dict[str, Session]
     _webio_sessions = {}  # WebIOSessionID -> WebIOSession()
     _webio_expire = LRUDict()  # WebIOSessionID -> last active timestamp。按照最后活跃时间递增排列
 

+ 2 - 2
pywebio/platform/tornado.py

@@ -15,7 +15,7 @@ from tornado.web import StaticFileHandler
 from tornado.websocket import WebSocketHandler
 
 from ..session import CoroutineBasedSession, ThreadBasedSession, ScriptModeSession, \
-    register_session_implement_for_target, AbstractSession
+    register_session_implement_for_target, Session
 from ..session.base import get_session_info_from_headers
 from ..utils import get_free_port, wait_host_port, STATIC_PATH
 
@@ -69,7 +69,7 @@ def _webio_handler(target, session_cls, check_origin_func=_is_same_site):
             # Non-None enables compression with default options.
             return {}
 
-        def send_msg_to_client(self, session: AbstractSession):
+        def send_msg_to_client(self, session: Session):
             for msg in session.get_task_commands():
                 self.write_message(json.dumps(msg))
 

+ 9 - 2
pywebio/session/__init__.py

@@ -14,7 +14,7 @@ r"""
 import threading
 from functools import wraps
 
-from .base import AbstractSession
+from .base import Session
 from .coroutinebased import CoroutineBasedSession
 from .threadbased import ThreadBasedSession, ScriptModeSession
 from ..exceptions import SessionNotFoundException
@@ -68,7 +68,7 @@ def _start_script_mode_server():
     start_server_in_current_thread_session()
 
 
-def get_current_session() -> "AbstractSession":
+def get_current_session() -> "Session":
     return get_session_implement().get_current_session()
 
 
@@ -179,6 +179,13 @@ def defer_call(func):
     return func
 
 
+def data():
+    """获取当前会话的数据对象,用于在对象上保存一些会话相关的数据。访问数据对象不存在的属性时会返回None而不是抛出异常。
+
+    """
+    return get_current_session().save
+
+
 def get_info():
     """ 获取当前会话的相关信息
 

+ 29 - 22
pywebio/session/base.py

@@ -1,13 +1,21 @@
+import logging
+
 import user_agents
-from ..utils import ObjectDict
 
+from ..utils import ObjectDict, Setter, catch_exp_call
+
+logger = logging.getLogger(__name__)
 
-class AbstractSession:
+
+class Session:
     """
     会话对象,由Backend创建
 
     属性:
         info 表示会话信息的对象
+        save 会话的数据对象,提供用户在对象上保存一些会话相关数据
+
+        _save 用于内部实现的一些状态保存
 
     由Task在当前Session上下文中调用:
         get_current_session
@@ -27,39 +35,30 @@ class AbstractSession:
 
     Task和Backend都可调用:
         closed
-        active_session_count
-
 
     Session是不同的后端Backend与协程交互的桥梁:
         后端Backend在接收到用户浏览器的数据后,会通过调用 ``send_client_event`` 来通知会话,进而由Session驱动协程的运行。
         Task内在调用输入输出函数后,会调用 ``send_task_command`` 向会话发送输入输出消息指令, Session将其保存并留给后端Backend处理。
     """
-    info = object()
-
-    @staticmethod
-    def active_session_count() -> int:
-        raise NotImplementedError
 
     @staticmethod
-    def get_current_session() -> "AbstractSession":
+    def get_current_session() -> "Session":
         raise NotImplementedError
 
     @staticmethod
     def get_current_task_id():
         raise NotImplementedError
 
-    def __init__(self, target, session_info, on_task_command=None, on_session_close=None, **kwargs):
+    def __init__(self, session_info):
         """
-        :param target:
         :param session_info: 会话信息。可以通过 Session.info 访问
-        :param on_task_command: Backend向ession注册的处理函数,当 Session 收到task发送的command时调用
-        :param on_session_close: Backend向Session注册的处理函数,当 Session task 执行结束时调用 *
-        :param kwargs:
-
-        .. note::
-            后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
         """
-        raise NotImplementedError
+        self.info = session_info
+        self.save = Setter()
+        self._save = Setter()
+
+        self.deferred_functions = []  # 会话结束时运行的函数
+        self._closed = False
 
     def send_task_command(self, command):
         raise NotImplementedError
@@ -75,10 +74,17 @@ class AbstractSession:
         raise NotImplementedError
 
     def close(self):
-        raise NotImplementedError
+        if self._closed:
+            return
+        self._closed = True
+
+        self.deferred_functions.reverse()
+        while self.deferred_functions:
+            func = self.deferred_functions.pop()
+            catch_exp_call(func, logger)
 
     def closed(self) -> bool:
-        raise NotImplementedError
+        return self._closed
 
     def on_task_exception(self):
         raise NotImplementedError
@@ -97,7 +103,8 @@ class AbstractSession:
 
         :param func: 话结束时调用的函数
         """
-        raise NotImplementedError
+        """设置会话结束时调用的函数。可以用于资源清理。"""
+        self.deferred_functions.append(func)
 
 
 def get_session_info_from_headers(headers):

+ 8 - 28
pywebio/session/coroutinebased.py

@@ -6,9 +6,9 @@ import traceback
 from contextlib import contextmanager
 from functools import partial
 
-from .base import AbstractSession
+from .base import Session
 from ..exceptions import SessionNotFoundException, SessionClosedException, SessionException
-from ..utils import random_str, isgeneratorfunction, iscoroutinefunction, catch_exp_call
+from ..utils import random_str, isgeneratorfunction, iscoroutinefunction
 
 logger = logging.getLogger(__name__)
 
@@ -29,7 +29,7 @@ class _context:
     current_task_id = None
 
 
-class CoroutineBasedSession(AbstractSession):
+class CoroutineBasedSession(Session):
     """
     基于协程的任务会话
 
@@ -44,12 +44,6 @@ class CoroutineBasedSession(AbstractSession):
     # Flask backend时,在platform.flaskrun_event_loop()时初始化
     event_loop_thread_id = None
 
-    _active_session_cnt = 0
-
-    @classmethod
-    def active_session_count(cls):
-        return cls._active_session_cnt
-
     @classmethod
     def get_current_session(cls) -> "CoroutineBasedSession":
         if _context.current_session is None or cls.event_loop_thread_id != threading.current_thread().ident:
@@ -75,16 +69,13 @@ class CoroutineBasedSession(AbstractSession):
         assert iscoroutinefunction(target) or isgeneratorfunction(target), ValueError(
             "CoroutineBasedSession accept coroutine function or generator function as task function")
 
+        super().__init__(session_info)
+
         cls = type(self)
-        cls._active_session_cnt += 1
 
-        self.info = session_info
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
 
-        # 会话结束时运行的函数
-        self.deferred_functions = []
-
         # 当前会话未被Backend处理的消息
         self.unhandled_task_msgs = []
 
@@ -158,22 +149,15 @@ class CoroutineBasedSession(AbstractSession):
             t.step(SessionClosedException, throw_exp=True)
             t.close()
         self.coros = {}  # delete session tasks
-        type(self)._active_session_cnt -= 1
 
     def close(self):
         """关闭当前Session。由Backend调用"""
-        if self._closed:
+        if self.closed():
             return
-        self._closed = True
-        self._cleanup()
 
-        self.deferred_functions.reverse()
-        while self.deferred_functions:
-            func = self.deferred_functions.pop()
-            catch_exp_call(func, logger)
+        super().close()
 
-    def closed(self):
-        return self._closed
+        self._cleanup()
 
     def on_task_exception(self):
         from ..output import put_markdown  # todo
@@ -252,10 +236,6 @@ class CoroutineBasedSession(AbstractSession):
         res = await WebIOFuture(coro=coro_obj)
         return res
 
-    def defer_call(self, func):
-        """设置会话结束时调用的函数。可以用于资源清理。"""
-        self.deferred_functions.append(func)
-
 
 class TaskHandle:
     """协程任务句柄"""

+ 7 - 30
pywebio/session/threadbased.py

@@ -5,9 +5,9 @@ import threading
 import traceback
 from functools import wraps
 
-from .base import AbstractSession
+from .base import Session
 from ..exceptions import SessionNotFoundException, SessionClosedException, SessionException
-from ..utils import random_str, LimitedSizeQueue, isgeneratorfunction, iscoroutinefunction, catch_exp_call, \
+from ..utils import random_str, LimitedSizeQueue, isgeneratorfunction, iscoroutinefunction, \
     get_function_name
 
 logger = logging.getLogger(__name__)
@@ -22,19 +22,13 @@ logger = logging.getLogger(__name__)
 
 
 # todo 线程安全
-class ThreadBasedSession(AbstractSession):
+class ThreadBasedSession(Session):
     thread2session = {}  # thread_id -> session
 
     unhandled_task_mq_maxsize = 1000
     event_mq_maxsize = 100
     callback_mq_maxsize = 100
 
-    _active_session_cnt = 0
-
-    @classmethod
-    def active_session_count(cls):
-        return cls._active_session_cnt
-
     @classmethod
     def get_current_session(cls) -> "ThreadBasedSession":
         curr = id(threading.current_thread())
@@ -67,16 +61,12 @@ class ThreadBasedSession(AbstractSession):
             "ThreadBasedSession only accept a simple function as task function, "
             "not coroutine function or generator function. ")
 
-        type(self)._active_session_cnt += 1
+        super().__init__(session_info)
 
-        self.info = session_info
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
         self._loop = loop
 
-        # 会话结束时运行的函数
-        self.deferred_functions = []
-
         self.threads = []  # 注册到当前会话的线程集合
         self.unhandled_task_msgs = LimitedSizeQueue(maxsize=self.unhandled_task_mq_maxsize)
 
@@ -188,24 +178,15 @@ class ThreadBasedSession(AbstractSession):
 
         self.task_mqs = {}
 
-        cls._active_session_cnt -= 1
-
     def close(self):
         """关闭当前Session。由Backend调用"""
         # todo self._closed 会有竞争条件
-        if self._closed:
+        if self.closed():
             return
-        self._closed = True
-
-        self._cleanup()
 
-        self.deferred_functions.reverse()
-        while self.deferred_functions:
-            func = self.deferred_functions.pop()
-            catch_exp_call(func, logger)
+        super().close()
 
-    def closed(self):
-        return self._closed
+        self._cleanup()
 
     def on_task_exception(self):
         from ..output import put_markdown  # todo
@@ -294,10 +275,6 @@ class ThreadBasedSession(AbstractSession):
         event_mq = queue.Queue(maxsize=self.event_mq_maxsize)  # 线程内的用户事件队列
         self.task_mqs[self._get_task_id(t)] = event_mq
 
-    def defer_call(self, func):
-        """设置会话结束时调用的函数。可以用于资源清理。"""
-        self.deferred_functions.append(func)
-
 
 class ScriptModeSession(ThreadBasedSession):
     """Script mode的会话实现"""

+ 13 - 0
pywebio/utils.py

@@ -16,6 +16,19 @@ project_dir = dirname(abspath(__file__))
 STATIC_PATH = '%s/html' % project_dir
 
 
+class Setter:
+    """
+    可以在对象属性上保存数据。
+    访问数据对象不存在的属性时会返回None而不是抛出异常。
+    """
+
+    def __getattribute__(self, name):
+        try:
+            return super().__getattribute__(name)
+        except AttributeError:
+            return None
+
+
 class ObjectDict(dict):
     """
     Object like dict, every dict[key] can visite by dict.key