فهرست منبع

fix: threadbased session block in `close()`

wangweimin 4 سال پیش
والد
کامیت
94bec9539c

+ 2 - 1
pywebio/platform/aiohttp.py

@@ -103,7 +103,8 @@ def _webio_handler(applications, cdn, websocket_settings, check_origin_func=_is_
                 pass
                 pass
             elif msg.type == web.WSMsgType.close:
             elif msg.type == web.WSMsgType.close:
                 if not close_from_session_tag:
                 if not close_from_session_tag:
-                    session.close()
+                    # close session because client disconnected to server
+                    session.close(nonblock=True)
                     logger.debug("WebSocket closed from client")
                     logger.debug("WebSocket closed from client")
 
 
         return ws
         return ws

+ 1 - 1
pywebio/platform/httpbased.py

@@ -125,7 +125,7 @@ class HttpHandler:
             logger.debug("session %s expired" % sid)
             logger.debug("session %s expired" % sid)
             session = cls._webio_sessions.get(sid)
             session = cls._webio_sessions.get(sid)
             if session:
             if session:
-                session.close()
+                session.close(nonblock=True)
                 del cls._webio_sessions[sid]
                 del cls._webio_sessions[sid]
 
 
     @classmethod
     @classmethod

+ 1 - 1
pywebio/platform/tornado.py

@@ -143,7 +143,7 @@ def _webio_handler(applications=None, cdn=True, check_origin_func=_is_same_site)
             # Session.close() is called only when connection is closed from the client.
             # Session.close() is called only when connection is closed from the client.
             # 只有在由客户端主动断开连接时,才调用 session.close()
             # 只有在由客户端主动断开连接时,才调用 session.close()
             if not self._close_from_session_tag:
             if not self._close_from_session_tag:
-                self.session.close()
+                self.session.close(nonblock=True)
             logger.debug("WebSocket closed")
             logger.debug("WebSocket closed")
 
 
     return WSHandler
     return WSHandler

+ 5 - 1
pywebio/session/base.py

@@ -105,7 +105,11 @@ class Session:
     def get_task_commands(self) -> list:
     def get_task_commands(self) -> list:
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def close(self):
+    def close(self, nonblock=False):
+        """Close current session
+
+        :param bool nonblock: Don't block thread. Used in closing from backend.
+        """
         if self._closed:
         if self._closed:
             return
             return
         self._closed = True
         self._closed = True

+ 1 - 1
pywebio/session/coroutinebased.py

@@ -148,7 +148,7 @@ class CoroutineBasedSession(Session):
             t.close()
             t.close()
         self.coros = {}  # delete session tasks
         self.coros = {}  # delete session tasks
 
 
-    def close(self):
+    def close(self, nonblock=False):
         """关闭当前Session。由Backend调用"""
         """关闭当前Session。由Backend调用"""
         if self.closed():
         if self.closed():
             return
             return

+ 20 - 13
pywebio/session/threadbased.py

@@ -125,6 +125,8 @@ class ThreadBasedSession(Session):
 
 
         task_id = self.get_current_task_id()
         task_id = self.get_current_task_id()
         event_mq = self.get_current_session().task_mqs.get(task_id)
         event_mq = self.get_current_session().task_mqs.get(task_id)
+        if event_mq is None:
+            raise SessionNotFoundException
         event = event_mq.get()
         event = event_mq.get()
         if event is None:
         if event is None:
             raise SessionClosedException
             raise SessionClosedException
@@ -144,7 +146,10 @@ class ThreadBasedSession(Session):
             logger.error('event_mqs not found, task_id:%s', task_id)
             logger.error('event_mqs not found, task_id:%s', task_id)
             return
             return
 
 
-        mq.put(event)
+        try:
+            mq.put_nowait(event)  # disable blocking, because this is call by backend
+        except queue.Full:
+            logger.error('Message queue is full, discard new messages')  # todo: alert user
 
 
     def get_task_commands(self):
     def get_task_commands(self):
         return self.unhandled_task_msgs.get()
         return self.unhandled_task_msgs.get()
@@ -156,26 +161,28 @@ class ThreadBasedSession(Session):
         else:
         else:
             self._on_session_close()
             self._on_session_close()
 
 
-    def _cleanup(self):
+    def _cleanup(self, nonblock=False):
         cls = type(self)
         cls = type(self)
+        if not nonblock:
+            self.unhandled_task_msgs.wait_empty(8)
 
 
-        self.unhandled_task_msgs.wait_empty(8)
         if not self.unhandled_task_msgs.empty():
         if not self.unhandled_task_msgs.empty():
-            logger.debug("Unhandled task messages when session close:%s", self.unhandled_task_msgs.get())
-            raise RuntimeError('There are unhandled task messages when session close!')
+            msg = self.unhandled_task_msgs.get()
+            logger.warning("%d unhandled task messages when session close. [%s]", len(msg), threading.current_thread())
 
 
         for t in self.threads:
         for t in self.threads:
+            # delete registered thread
+            # so the `get_current_session()` call in those thread will raise SessionNotFoundException
             del cls.thread2session[id(t)]
             del cls.thread2session[id(t)]
 
 
-        if self.callback_mq is not None:  # 回调功能已经激活
-            self.callback_mq.put(None)  # 结束回调线程
-
-        for mq in self.task_mqs.values():
-            mq.put(None)  # 消费端接收到None消息会抛出SessionClosedException异常
+        if self.callback_mq is not None:  # 回调功能已经激活, 结束回调线程
+            mq = queue.Queue(maxsize=1)
+            mq.put(None)
+            self.callback_mq = mq
 
 
         self.task_mqs = {}
         self.task_mqs = {}
 
 
-    def close(self):
+    def close(self, nonblock=False):
         """关闭当前Session。由Backend调用"""
         """关闭当前Session。由Backend调用"""
         # todo self._closed 会有竞争条件
         # todo self._closed 会有竞争条件
         if self.closed():
         if self.closed():
@@ -183,7 +190,7 @@ class ThreadBasedSession(Session):
 
 
         super().close()
         super().close()
 
 
-        self._cleanup()
+        self._cleanup(nonblock=nonblock)
 
 
     def _activate_callback_env(self):
     def _activate_callback_env(self):
         """激活回调功能
         """激活回调功能
@@ -200,7 +207,7 @@ class ThreadBasedSession(Session):
                                                 daemon=True, name='callback-' + random_str(10))
                                                 daemon=True, name='callback-' + random_str(10))
         # self.register_thread(self.callback_thread)
         # self.register_thread(self.callback_thread)
         self.thread2session[id(self.callback_thread)] = self  # 用于在线程内获取会话
         self.thread2session[id(self.callback_thread)] = self  # 用于在线程内获取会话
-        event_mq = queue.Queue(maxsize=self.event_mq_maxsize)  # 线程内的用户事件队列
+        event_mq = queue.Queue(maxsize=self.event_mq_maxsize)  # 回调线程内的用户事件队列
         self.task_mqs[self._get_task_id(self.callback_thread)] = event_mq
         self.task_mqs[self._get_task_id(self.callback_thread)] = event_mq
 
 
         self.callback_thread.start()
         self.callback_thread.start()