|
@@ -125,6 +125,8 @@ class ThreadBasedSession(Session):
|
|
|
|
|
|
task_id = self.get_current_task_id()
|
|
task_id = self.get_current_task_id()
|
|
event_mq = self.get_current_session().task_mqs.get(task_id)
|
|
event_mq = self.get_current_session().task_mqs.get(task_id)
|
|
|
|
+ if event_mq is None:
|
|
|
|
+ raise SessionNotFoundException
|
|
event = event_mq.get()
|
|
event = event_mq.get()
|
|
if event is None:
|
|
if event is None:
|
|
raise SessionClosedException
|
|
raise SessionClosedException
|
|
@@ -144,7 +146,10 @@ class ThreadBasedSession(Session):
|
|
logger.error('event_mqs not found, task_id:%s', task_id)
|
|
logger.error('event_mqs not found, task_id:%s', task_id)
|
|
return
|
|
return
|
|
|
|
|
|
- mq.put(event)
|
|
|
|
|
|
+ try:
|
|
|
|
+ mq.put_nowait(event) # disable blocking, because this is call by backend
|
|
|
|
+ except queue.Full:
|
|
|
|
+ logger.error('Message queue is full, discard new messages') # todo: alert user
|
|
|
|
|
|
def get_task_commands(self):
|
|
def get_task_commands(self):
|
|
return self.unhandled_task_msgs.get()
|
|
return self.unhandled_task_msgs.get()
|
|
@@ -156,26 +161,28 @@ class ThreadBasedSession(Session):
|
|
else:
|
|
else:
|
|
self._on_session_close()
|
|
self._on_session_close()
|
|
|
|
|
|
- def _cleanup(self):
|
|
|
|
|
|
+ def _cleanup(self, nonblock=False):
|
|
cls = type(self)
|
|
cls = type(self)
|
|
|
|
+ if not nonblock:
|
|
|
|
+ self.unhandled_task_msgs.wait_empty(8)
|
|
|
|
|
|
- self.unhandled_task_msgs.wait_empty(8)
|
|
|
|
if not self.unhandled_task_msgs.empty():
|
|
if not self.unhandled_task_msgs.empty():
|
|
- logger.debug("Unhandled task messages when session close:%s", self.unhandled_task_msgs.get())
|
|
|
|
- raise RuntimeError('There are unhandled task messages when session close!')
|
|
|
|
|
|
+ msg = self.unhandled_task_msgs.get()
|
|
|
|
+ logger.warning("%d unhandled task messages when session close. [%s]", len(msg), threading.current_thread())
|
|
|
|
|
|
for t in self.threads:
|
|
for t in self.threads:
|
|
|
|
+ # delete registered thread
|
|
|
|
+ # so the `get_current_session()` call in those thread will raise SessionNotFoundException
|
|
del cls.thread2session[id(t)]
|
|
del cls.thread2session[id(t)]
|
|
|
|
|
|
- if self.callback_mq is not None: # 回调功能已经激活
|
|
|
|
- self.callback_mq.put(None) # 结束回调线程
|
|
|
|
-
|
|
|
|
- for mq in self.task_mqs.values():
|
|
|
|
- mq.put(None) # 消费端接收到None消息会抛出SessionClosedException异常
|
|
|
|
|
|
+ if self.callback_mq is not None: # 回调功能已经激活, 结束回调线程
|
|
|
|
+ mq = queue.Queue(maxsize=1)
|
|
|
|
+ mq.put(None)
|
|
|
|
+ self.callback_mq = mq
|
|
|
|
|
|
self.task_mqs = {}
|
|
self.task_mqs = {}
|
|
|
|
|
|
- def close(self):
|
|
|
|
|
|
+ def close(self, nonblock=False):
|
|
"""关闭当前Session。由Backend调用"""
|
|
"""关闭当前Session。由Backend调用"""
|
|
# todo self._closed 会有竞争条件
|
|
# todo self._closed 会有竞争条件
|
|
if self.closed():
|
|
if self.closed():
|
|
@@ -183,7 +190,7 @@ class ThreadBasedSession(Session):
|
|
|
|
|
|
super().close()
|
|
super().close()
|
|
|
|
|
|
- self._cleanup()
|
|
|
|
|
|
+ self._cleanup(nonblock=nonblock)
|
|
|
|
|
|
def _activate_callback_env(self):
|
|
def _activate_callback_env(self):
|
|
"""激活回调功能
|
|
"""激活回调功能
|
|
@@ -200,7 +207,7 @@ class ThreadBasedSession(Session):
|
|
daemon=True, name='callback-' + random_str(10))
|
|
daemon=True, name='callback-' + random_str(10))
|
|
# self.register_thread(self.callback_thread)
|
|
# self.register_thread(self.callback_thread)
|
|
self.thread2session[id(self.callback_thread)] = self # 用于在线程内获取会话
|
|
self.thread2session[id(self.callback_thread)] = self # 用于在线程内获取会话
|
|
- event_mq = queue.Queue(maxsize=self.event_mq_maxsize) # 线程内的用户事件队列
|
|
|
|
|
|
+ event_mq = queue.Queue(maxsize=self.event_mq_maxsize) # 回调线程内的用户事件队列
|
|
self.task_mqs[self._get_task_id(self.callback_thread)] = event_mq
|
|
self.task_mqs[self._get_task_id(self.callback_thread)] = event_mq
|
|
|
|
|
|
self.callback_thread.start()
|
|
self.callback_thread.start()
|