threadbased.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import asyncio
  2. import inspect
  3. import logging
  4. import queue
  5. import sys
  6. import threading
  7. import traceback
  8. from .base import AbstractSession
  9. from ..exceptions import SessionNotFoundException, SessionClosedException
  10. from ..utils import random_str, LimitedSizeQueue
  11. logger = logging.getLogger(__name__)
  12. """
  13. 基于线程的会话实现
  14. 主任务线程退出后,连接关闭,但不会清理主任务线程产生的其他线程
  15. 客户端连接关闭后,后端线程不会退出,但是再次调用输入输出函数会引发异常
  16. todo: thread 重名
  17. """
  18. # todo 线程安全
  19. class ThreadBasedSession(AbstractSession):
  20. thread2session = {} # thread_id -> session
  21. unhandled_task_mq_maxsize = 1000
  22. event_mq_maxsize = 100
  23. callback_mq_maxsize = 100
  24. _active_session_cnt = 0
  25. @classmethod
  26. def active_session_count(cls):
  27. return cls._active_session_cnt
  28. @classmethod
  29. def get_current_session(cls) -> "ThreadBasedSession":
  30. curr = id(threading.current_thread())
  31. session = cls.thread2session.get(curr)
  32. if session is None:
  33. raise SessionNotFoundException(
  34. "Can't find current session. Maybe session closed. Did you forget to use `register_thread` ?")
  35. return session
  36. @classmethod
  37. def get_current_task_id(cls):
  38. return cls._get_task_id(threading.current_thread())
  39. @staticmethod
  40. def _get_task_id(thread: threading.Thread):
  41. tname = getattr(thread, '_target', 'task')
  42. tname = getattr(tname, '__name__', tname)
  43. return '%s-%s' % (tname, id(thread))
  44. def __init__(self, target, on_task_command=None, on_session_close=None, loop=None):
  45. """
  46. :param target: 会话运行的函数
  47. :param on_task_command: 当Task内发送Command给session的时候触发的处理函数
  48. :param on_session_close: 会话结束的处理函数
  49. :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
  50. 则需要事件循环实例来将回调在事件循环的线程中执行
  51. """
  52. assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
  53. "ThreadBasedSession only accept a simple function as task function, "
  54. "not coroutine function or generator function. ")
  55. ThreadBasedSession._active_session_cnt += 1
  56. self._on_task_command = on_task_command or (lambda _: None)
  57. self._on_session_close = on_session_close or (lambda: None)
  58. self._loop = loop
  59. self.threads = [] # 注册到当前会话的线程集合
  60. self.unhandled_task_msgs = LimitedSizeQueue(maxsize=self.unhandled_task_mq_maxsize)
  61. self.task_mqs = {} # task_id -> event msg queue
  62. self._closed = False
  63. # 用于实现回调函数的注册
  64. self.callback_mq = None
  65. self.callback_thread = None
  66. self.callbacks = {} # callback_id -> (callback_func, is_mutex)
  67. self._start_main_task(target)
  68. def _start_main_task(self, target):
  69. def main_task(target):
  70. try:
  71. target()
  72. except Exception as e:
  73. self.on_task_exception()
  74. finally:
  75. for t in self.threads:
  76. if t.is_alive() and t is not threading.current_thread():
  77. t.join()
  78. self.send_task_command(dict(command='close_session'))
  79. self._trigger_close_event()
  80. self.close()
  81. main_task.__name__ = getattr(target, '__name__', 'main')
  82. thread = threading.Thread(target=main_task, kwargs=dict(target=target),
  83. daemon=True, name='main_task')
  84. self.register_thread(thread)
  85. thread.start()
  86. def send_task_command(self, command):
  87. """向会话发送来自协程内的消息
  88. :param dict command: 消息
  89. """
  90. self.unhandled_task_msgs.put(command)
  91. if self._loop:
  92. self._loop.call_soon_threadsafe(self._on_task_command, self)
  93. else:
  94. self._on_task_command(self)
  95. def next_client_event(self):
  96. task_id = self.get_current_task_id()
  97. event_mq = self.get_current_session().task_mqs.get(task_id)
  98. return event_mq.get()
  99. def send_client_event(self, event):
  100. """向会话发送来自用户浏览器的事件️
  101. :param dict event: 事件️消息
  102. """
  103. task_id = event['task_id']
  104. mq = self.task_mqs.get(task_id)
  105. if not mq and task_id in self.callbacks:
  106. mq = self.callback_mq
  107. if not mq:
  108. logger.error('event_mqs not found, task_id:%s', task_id)
  109. return
  110. mq.put(event)
  111. def get_task_commands(self):
  112. return self.unhandled_task_msgs.get()
  113. def _trigger_close_event(self):
  114. """触发Backend on_session_close callback"""
  115. if self._loop:
  116. self._loop.call_soon_threadsafe(self._on_session_close)
  117. else:
  118. self._on_session_close()
  119. def _cleanup(self):
  120. self.task_mqs = {}
  121. if not self.unhandled_task_msgs.empty():
  122. raise RuntimeError('There are unhandled task msgs when session close!')
  123. for t in self.threads:
  124. del ThreadBasedSession.thread2session[id(t)]
  125. if self.callback_mq is not None: # 回调功能已经激活
  126. self.callback_mq.put(None) # 结束回调线程
  127. ThreadBasedSession._active_session_cnt -= 1
  128. def close(self):
  129. """关闭当前Session。由Backend调用"""
  130. if self._closed:
  131. return
  132. self._closed = True
  133. self._cleanup()
  134. def closed(self):
  135. return self._closed
  136. def on_task_exception(self):
  137. from ..output import put_markdown # todo
  138. logger.exception('Error in coroutine executing')
  139. type, value, tb = sys.exc_info()
  140. tb_len = len(list(traceback.walk_tb(tb)))
  141. lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
  142. traceback_msg = ''.join(lines)
  143. try:
  144. put_markdown("发生错误:\n```\n%s\n```" % traceback_msg)
  145. except:
  146. pass
  147. def _activate_callback_env(self):
  148. """激活回调功能
  149. ThreadBasedSession 的回调实现原理是:创建一个单独的线程用于接收回调事件,进而调用相关的回调函数。
  150. 当用户Task中并没有使用到回调功能时,不必开启此线程,可以节省资源
  151. """
  152. if self.callback_mq is not None: # 回调功能已经激活
  153. return
  154. self.callback_mq = queue.Queue(maxsize=self.callback_mq_maxsize)
  155. self.callback_thread = threading.Thread(target=self._dispatch_callback_event,
  156. daemon=True, name='callback-' + random_str(10))
  157. self.register_thread(self.callback_thread)
  158. self.callback_thread.start()
  159. logger.debug('Callback thread start')
  160. def _dispatch_callback_event(self):
  161. while not self.closed():
  162. event = self.callback_mq.get()
  163. if event is None: # 结束信号
  164. break
  165. callback_info = self.callbacks.get(event['task_id'])
  166. if not callback_info:
  167. logger.error("No callback for callback_id:%s", event['task_id'])
  168. return
  169. callback, mutex = callback_info
  170. def run(callback):
  171. try:
  172. callback(event['data'])
  173. except:
  174. # 子类可能会重写 get_current_session ,所以不要用 ThreadBasedSession.get_current_session 来调用
  175. self.get_current_session().on_task_exception()
  176. if mutex:
  177. run(callback)
  178. else:
  179. t = threading.Thread(target=run, kwargs=dict(callback=callback),
  180. daemon=True)
  181. self.register_thread(t)
  182. t.start()
  183. def register_callback(self, callback, serial_mode=False):
  184. """ 向Session注册一个回调函数,返回回调id
  185. Session需要保证当收到前端发送的事件消息 ``{event: "callback",task_id: 回调id, data:...}`` 时,
  186. ``callback`` 回调函数被执行, 并传入事件消息中的 ``data`` 字段值作为参数
  187. :param bool serial_mode: 串行模式模式。若为 ``True`` ,则对于同一组件的点击事件,串行执行其回调函数
  188. """
  189. assert (not asyncio.iscoroutinefunction(callback)) and (not inspect.isgeneratorfunction(callback)), ValueError(
  190. "In ThreadBasedSession.register_callback, `callback` must be a simple function, "
  191. "not coroutine function or generator function. ")
  192. self._activate_callback_env()
  193. callback_id = 'CB-%s-%s' % (getattr(callback, '__name__', ''), random_str(10))
  194. self.callbacks[callback_id] = (callback, serial_mode)
  195. return callback_id
  196. def register_thread(self, t: threading.Thread):
  197. """将线程注册到当前会话,以便在线程内调用 pywebio 交互函数。
  198. 会话会一直保持直到所有通过 `register_thread` 注册的线程以及当前会话的主任务线程退出
  199. :param threading.Thread thread: 线程对象
  200. """
  201. self.threads.append(t)
  202. self.thread2session[id(t)] = self
  203. event_mq = queue.Queue(maxsize=self.event_mq_maxsize)
  204. self.task_mqs[self._get_task_id(t)] = event_mq
  205. class ScriptModeSession(ThreadBasedSession):
  206. """Script mode的会话实现"""
  207. @classmethod
  208. def get_current_session(cls) -> "ScriptModeSession":
  209. if cls.instance is None:
  210. raise SessionNotFoundException("Can't find current session. It might be a bug.")
  211. if cls.instance.closed():
  212. raise SessionClosedException()
  213. return cls.instance
  214. @classmethod
  215. def get_current_task_id(cls):
  216. task_id = super().get_current_task_id()
  217. session = cls.get_current_session()
  218. if task_id not in session.task_mqs:
  219. session.register_thread(threading.current_thread())
  220. return task_id
  221. instance = None
  222. def __init__(self, thread, on_task_command=None, loop=None):
  223. """
  224. :param on_task_command: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
  225. 需要保证会话内的所有消息都传送到了客户端
  226. :param loop: 事件循环。若 on_task_command 或者on_session_close中有调用使用asyncio事件循环的调用,
  227. 则需要事件循环实例来将回调在事件循环的线程中执行
  228. """
  229. if ScriptModeSession.instance is not None:
  230. raise RuntimeError("ScriptModeSession can only be created once.")
  231. ScriptModeSession.instance = self
  232. ThreadBasedSession._active_session_cnt += 1
  233. self._on_task_command = on_task_command or (lambda _: None)
  234. self._on_session_close = lambda: None
  235. self._loop = loop
  236. self.threads = [] # 当前会话的线程
  237. self.unhandled_task_msgs = LimitedSizeQueue(maxsize=self.unhandled_task_mq_maxsize)
  238. self.task_mqs = {} # task_id -> event msg queue
  239. self._closed = False
  240. # 用于实现回调函数的注册
  241. self.callback_mq = None
  242. self.callback_thread = None
  243. self.callbacks = {} # callback_id -> (callback_func, is_mutex)
  244. tid = id(thread)
  245. event_mq = queue.Queue(maxsize=self.event_mq_maxsize)
  246. self.task_mqs[tid] = event_mq