|
@@ -5,9 +5,9 @@ import threading
|
|
|
import traceback
|
|
|
from functools import wraps
|
|
|
|
|
|
-from .base import AbstractSession
|
|
|
+from .base import Session
|
|
|
from ..exceptions import SessionNotFoundException, SessionClosedException, SessionException
|
|
|
-from ..utils import random_str, LimitedSizeQueue, isgeneratorfunction, iscoroutinefunction, catch_exp_call, \
|
|
|
+from ..utils import random_str, LimitedSizeQueue, isgeneratorfunction, iscoroutinefunction, \
|
|
|
get_function_name
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
@@ -22,19 +22,13 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
# todo 线程安全
|
|
|
-class ThreadBasedSession(AbstractSession):
|
|
|
+class ThreadBasedSession(Session):
|
|
|
thread2session = {} # thread_id -> session
|
|
|
|
|
|
unhandled_task_mq_maxsize = 1000
|
|
|
event_mq_maxsize = 100
|
|
|
callback_mq_maxsize = 100
|
|
|
|
|
|
- _active_session_cnt = 0
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def active_session_count(cls):
|
|
|
- return cls._active_session_cnt
|
|
|
-
|
|
|
@classmethod
|
|
|
def get_current_session(cls) -> "ThreadBasedSession":
|
|
|
curr = id(threading.current_thread())
|
|
@@ -67,16 +61,12 @@ class ThreadBasedSession(AbstractSession):
|
|
|
"ThreadBasedSession only accept a simple function as task function, "
|
|
|
"not coroutine function or generator function. ")
|
|
|
|
|
|
- type(self)._active_session_cnt += 1
|
|
|
+ super().__init__(session_info)
|
|
|
|
|
|
- self.info = session_info
|
|
|
self._on_task_command = on_task_command or (lambda _: None)
|
|
|
self._on_session_close = on_session_close or (lambda: None)
|
|
|
self._loop = loop
|
|
|
|
|
|
- # 会话结束时运行的函数
|
|
|
- self.deferred_functions = []
|
|
|
-
|
|
|
self.threads = [] # 注册到当前会话的线程集合
|
|
|
self.unhandled_task_msgs = LimitedSizeQueue(maxsize=self.unhandled_task_mq_maxsize)
|
|
|
|
|
@@ -188,24 +178,15 @@ class ThreadBasedSession(AbstractSession):
|
|
|
|
|
|
self.task_mqs = {}
|
|
|
|
|
|
- cls._active_session_cnt -= 1
|
|
|
-
|
|
|
def close(self):
|
|
|
"""关闭当前Session。由Backend调用"""
|
|
|
# todo self._closed 会有竞争条件
|
|
|
- if self._closed:
|
|
|
+ if self.closed():
|
|
|
return
|
|
|
- self._closed = True
|
|
|
-
|
|
|
- self._cleanup()
|
|
|
|
|
|
- self.deferred_functions.reverse()
|
|
|
- while self.deferred_functions:
|
|
|
- func = self.deferred_functions.pop()
|
|
|
- catch_exp_call(func, logger)
|
|
|
+ super().close()
|
|
|
|
|
|
- def closed(self):
|
|
|
- return self._closed
|
|
|
+ self._cleanup()
|
|
|
|
|
|
def on_task_exception(self):
|
|
|
from ..output import put_markdown # todo
|
|
@@ -294,10 +275,6 @@ class ThreadBasedSession(AbstractSession):
|
|
|
event_mq = queue.Queue(maxsize=self.event_mq_maxsize) # 线程内的用户事件队列
|
|
|
self.task_mqs[self._get_task_id(t)] = event_mq
|
|
|
|
|
|
- def defer_call(self, func):
|
|
|
- """设置会话结束时调用的函数。可以用于资源清理。"""
|
|
|
- self.deferred_functions.append(func)
|
|
|
-
|
|
|
|
|
|
class ScriptModeSession(ThreadBasedSession):
|
|
|
"""Script mode的会话实现"""
|