Browse Source

update
Session close logic

wangweimin 5 years ago
parent
commit
224ef896de

+ 4 - 4
pywebio/platform/tornado.py

@@ -34,11 +34,11 @@ def webio_handler(task_func):
             logger.debug("WebSocket opened")
             logger.debug("WebSocket opened")
             self.set_nodelay(True)
             self.set_nodelay(True)
 
 
-            self._close_from_session_tag = False  # 是否从session中关闭连接
+            self._close_from_session_tag = False  # 由session主动关闭连接
 
 
             if get_session_implement() is CoroutineBasedSession:
             if get_session_implement() is CoroutineBasedSession:
                 self.session = CoroutineBasedSession(task_func, on_task_command=self.send_msg_to_client,
                 self.session = CoroutineBasedSession(task_func, on_task_command=self.send_msg_to_client,
-                                                     on_session_close=self.close)
+                                                     on_session_close=self.close_from_session)
             else:
             else:
                 self.session = ThreadBasedSession(task_func, on_task_command=self.send_msg_to_client,
                 self.session = ThreadBasedSession(task_func, on_task_command=self.send_msg_to_client,
                                                   on_session_close=self.close_from_session,
                                                   on_session_close=self.close_from_session,
@@ -53,8 +53,8 @@ def webio_handler(task_func):
             self.close()
             self.close()
 
 
         def on_close(self):
         def on_close(self):
-            if not self._close_from_session_tag:
-                self.session.close(no_session_close_callback=True)
+            if not self._close_from_session_tag:  # 只有在由客户端主动断开连接时,才调用 session.close()
+                self.session.close()
             logger.debug("WebSocket closed")
             logger.debug("WebSocket closed")
 
 
     return WSHandler
     return WSHandler

+ 13 - 5
pywebio/session/base.py

@@ -1,5 +1,7 @@
 class AbstractSession:
 class AbstractSession:
     """
     """
+    会话对象,由Backend创建
+
     由Task在当前Session上下文中调用:
     由Task在当前Session上下文中调用:
         get_current_session
         get_current_session
         get_current_task_id
         get_current_task_id
@@ -12,10 +14,10 @@ class AbstractSession:
 
 
     由Backend调用:
     由Backend调用:
         send_client_event
         send_client_event
-        get_task_command
+        get_task_commands
+        close
 
 
     Task和Backend都可调用:
     Task和Backend都可调用:
-        close
         closed
         closed
 
 
     .. note::
     .. note::
@@ -31,24 +33,30 @@ class AbstractSession:
         raise NotImplementedError
         raise NotImplementedError
 
 
     def __init__(self, target, on_task_command=None, on_session_close=None, **kwargs):
     def __init__(self, target, on_task_command=None, on_session_close=None, **kwargs):
+        """
+        :param target:
+        :param on_task_command: Backend向ession注册的处理函数,当 Session 收到task发送的command时调用
+        :param on_session_close: Backend向Session注册的处理函数,当 Session task执行结束时调用 *
+        :param kwargs:
+        """
         raise NotImplementedError
         raise NotImplementedError
 
 
     def send_task_command(self, command):
     def send_task_command(self, command):
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def next_client_event(self):
+    def next_client_event(self) -> dict:
         raise NotImplementedError
         raise NotImplementedError
 
 
     def send_client_event(self, event):
     def send_client_event(self, event):
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def get_task_commands(self):
+    def get_task_commands(self) -> list:
         raise NotImplementedError
         raise NotImplementedError
 
 
     def close(self):
     def close(self):
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def closed(self):
+    def closed(self) -> bool:
         raise NotImplementedError
         raise NotImplementedError
 
 
     def on_task_exception(self):
     def on_task_exception(self):

+ 4 - 9
pywebio/session/coroutinebased.py

@@ -59,7 +59,7 @@ class CoroutineBasedSession(AbstractSession):
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
         """
         """
         assert asyncio.iscoroutinefunction(target) or inspect.isgeneratorfunction(target), ValueError(
         assert asyncio.iscoroutinefunction(target) or inspect.isgeneratorfunction(target), ValueError(
-            "In CoroutineBasedSession accept coroutine function or generator function as task function")
+            "CoroutineBasedSession accept coroutine function or generator function as task function")
 
 
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
         self._on_session_close = on_session_close or (lambda: None)
@@ -93,6 +93,7 @@ class CoroutineBasedSession(AbstractSession):
 
 
     def _on_main_task_finish(self):
     def _on_main_task_finish(self):
         self.send_task_command(dict(command='close_session'))
         self.send_task_command(dict(command='close_session'))
+        self._on_session_close()
         self.close()
         self.close()
 
 
     def send_task_command(self, command):
     def send_task_command(self, command):
@@ -134,16 +135,10 @@ class CoroutineBasedSession(AbstractSession):
             coro = self.inactive_coro_instances.pop()
             coro = self.inactive_coro_instances.pop()
             coro.close()
             coro.close()
 
 
-    def close(self, no_session_close_callback=False):
-        """关闭当前Session
-
-        :param bool no_session_close_callback: 不调用 on_session_close 会话结束的处理函数。
-            当 close 是由后端Backend调用时可能希望开启 no_session_close_callback
-        """
+    def close(self):
+        """关闭当前Session。由Backend调用"""
         self._cleanup()
         self._cleanup()
         self._closed = True
         self._closed = True
-        if not no_session_close_callback:
-            self._on_session_close()
         # todo clean
         # todo clean
 
 
     def closed(self):
     def closed(self):

+ 17 - 16
pywebio/session/threadbased.py

@@ -34,7 +34,8 @@ class ThreadBasedSession(AbstractSession):
         curr = threading.current_thread().getName()
         curr = threading.current_thread().getName()
         session = cls.thread2session.get(curr)
         session = cls.thread2session.get(curr)
         if session is None:
         if session is None:
-            raise SessionNotFoundException("Can't find current session. Maybe session closed. Did you forget to use `register_thread` ?")
+            raise SessionNotFoundException(
+                "Can't find current session. Maybe session closed. Did you forget to use `register_thread` ?")
         return session
         return session
 
 
     @staticmethod
     @staticmethod
@@ -45,8 +46,7 @@ class ThreadBasedSession(AbstractSession):
         """
         """
         :param target: 会话运行的函数
         :param target: 会话运行的函数
         :param on_task_command: 当Task内发送Command给session的时候触发的处理函数
         :param on_task_command: 当Task内发送Command给session的时候触发的处理函数
-        :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
-            需要保证会话内的所有消息都传送到了客户端
+        :param on_session_close: 会话结束的处理函数
         :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
         :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
             则需要事件循环实例来将回调在事件循环的线程中执行
             则需要事件循环实例来将回调在事件循环的线程中执行
         """
         """
@@ -70,7 +70,7 @@ class ThreadBasedSession(AbstractSession):
 
 
     def _start_main_task(self, target):
     def _start_main_task(self, target):
         assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
         assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
-            "In ThreadBasedSession.__init__, `target` must be a simple function, "
+            "ThreadBasedSession only accept a simple function as task function, "
             "not coroutine function or generator function. ")
             "not coroutine function or generator function. ")
 
 
         def thread_task(target):
         def thread_task(target):
@@ -80,6 +80,7 @@ class ThreadBasedSession(AbstractSession):
                 self.on_task_exception()
                 self.on_task_exception()
             finally:
             finally:
                 self.send_task_command(dict(command='close_session'))
                 self.send_task_command(dict(command='close_session'))
+                self._trigger_close_event()
                 self.close()
                 self.close()
 
 
         task_name = '%s-%s' % (target.__name__, random_str(10))
         task_name = '%s-%s' % (target.__name__, random_str(10))
@@ -96,6 +97,7 @@ class ThreadBasedSession(AbstractSession):
         """
         """
         with self._server_msg_lock:
         with self._server_msg_lock:
             self.unhandled_task_msgs.append(command)
             self.unhandled_task_msgs.append(command)
+
         if self._loop:
         if self._loop:
             self._loop.call_soon_threadsafe(self._on_task_command, self)
             self._loop.call_soon_threadsafe(self._on_task_command, self)
         else:
         else:
@@ -128,30 +130,29 @@ class ThreadBasedSession(AbstractSession):
             self.unhandled_task_msgs = []
             self.unhandled_task_msgs = []
         return msgs
         return msgs
 
 
+    def _trigger_close_event(self):
+        """触发Backend on_session_close callback"""
+        if self._loop:
+            self._loop.call_soon_threadsafe(self._on_session_close)
+        else:
+            self._on_session_close()
+
     def _cleanup(self):
     def _cleanup(self):
         self.event_mqs = {}
         self.event_mqs = {}
+
         # Don't clean unhandled_task_msgs, it may not send to client
         # Don't clean unhandled_task_msgs, it may not send to client
         # self.unhandled_task_msgs = []
         # self.unhandled_task_msgs = []
+
         for t in self.threads:
         for t in self.threads:
             del ThreadBasedSession.thread2session[t]
             del ThreadBasedSession.thread2session[t]
-            # pass
 
 
         if self.callback_mq is not None:  # 回调功能已经激活
         if self.callback_mq is not None:  # 回调功能已经激活
             self.callback_mq.put(None)  # 结束回调线程
             self.callback_mq.put(None)  # 结束回调线程
 
 
-    def close(self, no_session_close_callback=False):
-        """关闭当前Session
-
-        :param bool no_session_close_callback: 不调用 on_session_close 会话结束的处理函数。
-            当 close 是由后端Backend调用时可能希望开启 no_session_close_callback
-        """
+    def close(self):
+        """关闭当前Session。由Backend调用"""
         self._cleanup()
         self._cleanup()
         self._closed = True
         self._closed = True
-        if not no_session_close_callback:
-            if self._loop:
-                self._loop.call_soon_threadsafe(self._on_session_close)
-            else:
-                self._on_session_close()
 
 
     def closed(self):
     def closed(self):
         return self._closed
         return self._closed