threadbased.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. import logging
  2. import queue
  3. import threading
  4. from functools import wraps
  5. from .base import Session
  6. from ..exceptions import SessionNotFoundException, SessionClosedException, SessionException
  7. from ..utils import random_str, LimitedSizeQueue, isgeneratorfunction, iscoroutinefunction, \
  8. get_function_name
  9. logger = logging.getLogger(__name__)
  10. """
  11. 基于线程的会话实现
  12. 当任务函数返回并且会话内所有的通过 register_thread(thread) 注册的线程都退出后,会话结束,连接关闭。
  13. 正在等待PyWebIO输入的线程会在输入函数中抛出SessionClosedException异常,
  14. 其他线程若调用PyWebIO输入输出函数会引发异常SessionException
  15. """
  16. # todo 线程安全
  17. class ThreadBasedSession(Session):
  18. thread2session = {} # thread_id -> session
  19. unhandled_task_mq_maxsize = 1000
  20. event_mq_maxsize = 100
  21. callback_mq_maxsize = 100
  22. @classmethod
  23. def get_current_session(cls) -> "ThreadBasedSession":
  24. curr = id(threading.current_thread())
  25. session = cls.thread2session.get(curr)
  26. if session is None:
  27. raise SessionNotFoundException("Can't find current session. "
  28. "Maybe session closed or forget to use `register_thread()`.")
  29. return session
  30. @classmethod
  31. def get_current_task_id(cls):
  32. return cls._get_task_id(threading.current_thread())
  33. @staticmethod
  34. def _get_task_id(thread: threading.Thread):
  35. tname = getattr(thread, '_target', 'task')
  36. tname = getattr(tname, '__name__', tname)
  37. return '%s-%s' % (tname, id(thread))
  38. def __init__(self, target, session_info, on_task_command=None, on_session_close=None, loop=None):
  39. """
  40. :param target: 会话运行的函数. 为None时表示Script mode
  41. :param on_task_command: 当Task内发送Command给session的时候触发的处理函数
  42. :param on_session_close: 会话结束的处理函数
  43. :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
  44. 则需要事件循环实例来将回调在事件循环的线程中执行
  45. """
  46. assert target is None or (not iscoroutinefunction(target)) and (not isgeneratorfunction(target)), ValueError(
  47. "ThreadBasedSession only accept a simple function as task function, "
  48. "not coroutine function or generator function. ")
  49. super().__init__(session_info)
  50. self._on_task_command = on_task_command or (lambda _: None)
  51. self._on_session_close = on_session_close or (lambda: None)
  52. self._loop = loop
  53. self.threads = [] # 注册到当前会话的线程集合
  54. self.unhandled_task_msgs = LimitedSizeQueue(maxsize=self.unhandled_task_mq_maxsize)
  55. self.task_mqs = {} # task_id -> event msg queue
  56. self._closed = False
  57. # 用于实现回调函数的注册
  58. self.callback_mq = None
  59. self.callback_thread = None
  60. self.callbacks = {} # callback_id -> (callback_func, is_mutex)
  61. if target is not None:
  62. self._start_main_task(target)
  63. def _start_main_task(self, target):
  64. @wraps(target)
  65. def main_task(target):
  66. try:
  67. target()
  68. except Exception as e:
  69. if not isinstance(e, SessionException):
  70. self.on_task_exception()
  71. finally:
  72. for t in self.threads:
  73. if t.is_alive() and t is not threading.current_thread():
  74. t.join()
  75. try:
  76. if self.need_keep_alive():
  77. from ..session import hold
  78. hold()
  79. else:
  80. self.send_task_command(dict(command='close_session'))
  81. except SessionException: # ignore SessionException error
  82. pass
  83. finally:
  84. self._trigger_close_event()
  85. self.close()
  86. thread = threading.Thread(target=main_task, kwargs=dict(target=target),
  87. daemon=True, name='main_task')
  88. self.register_thread(thread)
  89. thread.start()
  90. def send_task_command(self, command):
  91. """向会话发送来自pywebio应用的消息
  92. :param dict command: 消息
  93. """
  94. if self.closed():
  95. raise SessionClosedException()
  96. self.unhandled_task_msgs.put(command)
  97. if self._loop:
  98. self._loop.call_soon_threadsafe(self._on_task_command, self)
  99. else:
  100. self._on_task_command(self)
  101. def next_client_event(self):
  102. # 函数开始不需要判断 self.closed()
  103. # 如果会话关闭,对 get_current_session().next_client_event() 的调用会抛出SessionNotFoundException
  104. task_id = self.get_current_task_id()
  105. event_mq = self.get_current_session().task_mqs.get(task_id)
  106. if event_mq is None:
  107. raise SessionNotFoundException
  108. event = event_mq.get()
  109. if event is None:
  110. raise SessionClosedException
  111. return event
  112. def send_client_event(self, event):
  113. """向会话发送来自用户浏览器的事件️
  114. :param dict event: 事件️消息
  115. """
  116. task_id = event['task_id']
  117. mq = self.task_mqs.get(task_id)
  118. if not mq and task_id in self.callbacks:
  119. mq = self.callback_mq
  120. if not mq:
  121. logger.error('event_mqs not found, task_id:%s', task_id)
  122. return
  123. try:
  124. mq.put_nowait(event) # disable blocking, because this is call by backend
  125. except queue.Full:
  126. logger.error('Message queue is full, discard new messages') # todo: alert user
  127. def get_task_commands(self):
  128. return self.unhandled_task_msgs.get()
  129. def _trigger_close_event(self):
  130. """触发Backend on_session_close callback"""
  131. if self.closed():
  132. return
  133. if self._loop:
  134. self._loop.call_soon_threadsafe(self._on_session_close)
  135. else:
  136. self._on_session_close()
  137. def _cleanup(self, nonblock=False):
  138. cls = type(self)
  139. if not nonblock:
  140. self.unhandled_task_msgs.wait_empty(8)
  141. if not self.unhandled_task_msgs.empty():
  142. msg = self.unhandled_task_msgs.get()
  143. logger.warning("%d unhandled task messages when session close. [%s]", len(msg), threading.current_thread())
  144. for t in self.threads:
  145. # delete registered thread
  146. # so the `get_current_session()` call in those thread will raise SessionNotFoundException
  147. del cls.thread2session[id(t)]
  148. if self.callback_thread:
  149. del cls.thread2session[id(self.callback_thread)]
  150. def try_best_to_add_item_to_mq(mq, item, try_count=10):
  151. for _ in range(try_count):
  152. try:
  153. mq.put(item, block=False)
  154. return True
  155. except queue.Full:
  156. try:
  157. mq.get(block=False)
  158. except queue.Empty:
  159. pass
  160. if self.callback_mq is not None: # 回调功能已经激活, 结束回调线程
  161. try_best_to_add_item_to_mq(self.callback_mq, None)
  162. for mq in self.task_mqs.values():
  163. try_best_to_add_item_to_mq(mq, None) # 消费端接收到None消息会抛出SessionClosedException异常
  164. self.task_mqs = {}
  165. def close(self, nonblock=False):
  166. """关闭当前Session。由Backend调用"""
  167. # todo self._closed 会有竞争条件
  168. if self.closed():
  169. return
  170. super().close()
  171. self._cleanup(nonblock=nonblock)
  172. def _activate_callback_env(self):
  173. """激活回调功能
  174. ThreadBasedSession 的回调实现原理是:创建一个单独的线程用于接收回调事件,进而调用相关的回调函数。
  175. 当用户Task中并没有使用到回调功能时,不必开启此线程,可以节省资源
  176. """
  177. if self.callback_mq is not None: # 回调功能已经激活
  178. return
  179. self.callback_mq = queue.Queue(maxsize=self.callback_mq_maxsize)
  180. self.callback_thread = threading.Thread(target=self._dispatch_callback_event,
  181. daemon=True, name='callback-' + random_str(10))
  182. # self.register_thread(self.callback_thread)
  183. self.thread2session[id(self.callback_thread)] = self # 用于在线程内获取会话
  184. event_mq = queue.Queue(maxsize=self.event_mq_maxsize) # 回调线程内的用户事件队列
  185. self.task_mqs[self._get_task_id(self.callback_thread)] = event_mq
  186. self.callback_thread.start()
  187. logger.debug('Callback thread start')
  188. def _dispatch_callback_event(self):
  189. while not self.closed():
  190. event = self.callback_mq.get()
  191. if event is None: # 结束信号
  192. logger.debug('Callback thread exit')
  193. break
  194. callback_info = self.callbacks.get(event['task_id'])
  195. if not callback_info:
  196. logger.error("No callback for callback_id:%s", event['task_id'])
  197. return
  198. callback, mutex = callback_info
  199. @wraps(callback)
  200. def run(callback):
  201. try:
  202. callback(event['data'])
  203. except Exception as e:
  204. # 子类可能会重写 get_current_session ,所以不要用 ThreadBasedSession.get_current_session 来调用
  205. if not isinstance(e, SessionException):
  206. self.on_task_exception()
  207. # todo: good to have -> clean up from `register_thread()`
  208. if mutex:
  209. run(callback)
  210. else:
  211. t = threading.Thread(target=run, kwargs=dict(callback=callback),
  212. daemon=True)
  213. self.register_thread(t)
  214. t.start()
  215. def register_callback(self, callback, serial_mode=False):
  216. """ 向Session注册一个回调函数,返回回调id
  217. :param Callable callback: 回调函数. 函数签名为 ``callback(data)``. ``data`` 参数为回调事件的值
  218. :param bool serial_mode: 串行模式模式。若为 ``True`` ,则对于同一组件的点击事件,串行执行其回调函数
  219. """
  220. assert (not iscoroutinefunction(callback)) and (not isgeneratorfunction(callback)), ValueError(
  221. "In ThreadBasedSession.register_callback, `callback` must be a simple function, "
  222. "not coroutine function or generator function. ")
  223. self._activate_callback_env()
  224. callback_id = 'CB-%s-%s' % (get_function_name(callback, 'callback'), random_str(10))
  225. self.callbacks[callback_id] = (callback, serial_mode)
  226. return callback_id
  227. def register_thread(self, t: threading.Thread):
  228. """将线程注册到当前会话,以便在线程内调用 pywebio 交互函数。
  229. 会话会一直保持直到所有通过 `register_thread` 注册的线程以及当前会话的主任务线程退出
  230. :param threading.Thread thread: 线程对象
  231. """
  232. self.threads.append(t) # 保存 registered thread,用于主任务线程退出后等待注册线程结束
  233. self.thread2session[id(t)] = self # 用于在线程内获取会话
  234. event_mq = queue.Queue(maxsize=self.event_mq_maxsize) # 线程内的用户事件队列
  235. self.task_mqs[self._get_task_id(t)] = event_mq
  236. def need_keep_alive(self) -> bool:
  237. # if callback thread is activated, then the session need to keep alive
  238. return self.callback_thread is not None
  239. class ScriptModeSession(ThreadBasedSession):
  240. """Script mode的会话实现"""
  241. @classmethod
  242. def get_current_session(cls) -> "ScriptModeSession":
  243. if cls.instance is None:
  244. raise SessionNotFoundException("Can't find current session. It might be a bug.")
  245. if cls.instance.closed():
  246. raise SessionClosedException()
  247. return cls.instance
  248. @classmethod
  249. def get_current_task_id(cls):
  250. task_id = super().get_current_task_id()
  251. session = cls.get_current_session()
  252. if task_id not in session.task_mqs:
  253. session.register_thread(threading.current_thread())
  254. return task_id
  255. instance = None
  256. def __init__(self, thread, session_info, on_task_command=None, loop=None):
  257. """
  258. :param thread: 第一次调用PyWebIO交互函数的线程 todo 貌似本参数并不必要
  259. :param on_task_command: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
  260. 需要保证会话内的所有消息都传送到了客户端
  261. :param loop: 事件循环。若 on_task_command 或者on_session_close中有调用使用asyncio事件循环的调用,
  262. 则需要事件循环实例来将回调在事件循环的线程中执行
  263. """
  264. if ScriptModeSession.instance is not None:
  265. raise RuntimeError("ScriptModeSession can only be created once.")
  266. ScriptModeSession.instance = self
  267. super().__init__(target=None, session_info=session_info, on_task_command=on_task_command, loop=loop)
  268. tid = id(thread)
  269. event_mq = queue.Queue(maxsize=self.event_mq_maxsize)
  270. self.task_mqs[tid] = event_mq