threadbased.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import logging
  2. import queue
  3. import sys
  4. import threading
  5. import traceback
  6. import asyncio, inspect
  7. from .base import AbstractSession
  8. from ..utils import random_str
  9. logger = logging.getLogger(__name__)
  10. """
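  11. Thread-based session implementation.
  12. When the main task thread exits, the connection is closed, but other threads spawned by the main task thread are not cleaned up.
  13. After the client connection is closed, backend threads do not exit, but calling get_current_session() from them again raises RuntimeError because the session is no longer registered.
  14. todo: thread names may collide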
  15. """
  16. # TODO: thread safety
  17. class ThreadBasedWebIOSession(AbstractSession):
  18. thread2session = {} # thread_id -> session
  19. event_mq_maxsize = 100
  20. callback_mq_maxsize = 100
  21. @classmethod
  22. def get_current_session(cls) -> "ThreadBasedWebIOSession":
  23. curr = threading.current_thread().getName()
  24. session = cls.thread2session.get(curr)
  25. if session is None:
  26. raise RuntimeError("Can't find current session. Maybe session closed.")
  27. return session
  28. @staticmethod
  29. def get_current_task_id():
  30. return threading.current_thread().getName()
  31. def __init__(self, target, on_task_message=None, on_session_close=None, loop=None):
  32. """
  33. :param target_func: 会话运行的函数
  34. :param on_coro_msg: 由协程内发给session的消息的处理函数
  35. :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
  36. 需要保证会话内的所有消息都传送到了客户端
  37. """
  38. self._on_task_message = on_task_message or (lambda _: None)
  39. self._on_session_close = on_session_close or (lambda: None)
  40. self._loop = loop
  41. self._server_msg_lock = threading.Lock()
  42. self.threads = [] # 当前会话的线程id集合,用户会话结束后,清理数据
  43. self.unhandled_task_msgs = []
  44. self.event_mqs = {} # thread_id -> event msg queue
  45. self._closed = False
  46. # 用于实现回调函数的注册
  47. self.callback_mq = None
  48. self.callback_thread = None
  49. self.callbacks = {} # callback_id -> (callback_func, is_mutex)
  50. main_task = self._new_thread_task(target, on_close=self.close)
  51. self.register_thread(main_task)
  52. main_task.start()
  53. def _new_thread_task(self, target, on_close=None):
  54. def thread_task(target):
  55. try:
  56. target()
  57. except Exception as e:
  58. self.on_task_exception()
  59. finally:
  60. if on_close:
  61. on_close()
  62. task_name = '%s-%s' % (target.__name__, random_str(10))
  63. thread = threading.Thread(target=thread_task, kwargs=dict(target=target),
  64. daemon=True, name=task_name)
  65. return thread
  66. def send_task_message(self, message):
  67. """向会话发送来自协程内的消息
  68. :param dict message: 消息
  69. """
  70. with self._server_msg_lock:
  71. self.unhandled_task_msgs.append(message)
  72. if self._loop:
  73. self._loop.call_soon_threadsafe(self._on_task_message, self)
  74. else:
  75. self._on_task_message(self)
  76. def next_client_event(self):
  77. name = threading.current_thread().getName()
  78. event_mq = self.get_current_session().event_mqs.get(name)
  79. return event_mq.get()
  80. def send_client_event(self, event):
  81. """向会话发送来自用户浏览器的事件️
  82. :param dict event: 事件️消息
  83. """
  84. task_id = event['coro_id']
  85. mq = self.event_mqs.get(task_id)
  86. if not mq and task_id in self.callbacks:
  87. mq = self.callback_mq
  88. if not mq:
  89. logger.error('event_mqs not found, task_id:%s', task_id)
  90. return
  91. mq.put(event)
  92. def get_task_messages(self):
  93. with self._server_msg_lock:
  94. msgs = self.unhandled_task_msgs
  95. self.unhandled_task_msgs = []
  96. return msgs
  97. def _cleanup(self):
  98. self.event_mqs = {}
  99. # Don't clean unhandled_task_msgs, it may not send to client
  100. # self.unhandled_task_msgs = []
  101. for t in self.threads:
  102. del ThreadBasedWebIOSession.thread2session[t]
  103. # pass
  104. if self.callback_mq is not None: # 回调功能已经激活
  105. self.callback_mq.put(None) # 结束回调线程
  106. def close(self, no_session_close_callback=False):
  107. """关闭当前Session
  108. :param bool no_session_close_callback: 不调用 on_session_close 会话结束的处理函数。
  109. 当 close 是由后端Backend调用时可能希望开启 no_session_close_callback
  110. """
  111. self._cleanup()
  112. self._closed = True
  113. if not no_session_close_callback:
  114. if self._loop:
  115. self._loop.call_soon_threadsafe(self._on_session_close)
  116. else:
  117. self._on_session_close()
  118. def closed(self):
  119. return self._closed
  120. def on_task_exception(self):
  121. from ..output import put_markdown # todo
  122. logger.exception('Error in coroutine executing')
  123. type, value, tb = sys.exc_info()
  124. tb_len = len(list(traceback.walk_tb(tb)))
  125. lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
  126. traceback_msg = ''.join(lines)
  127. put_markdown("发生错误:\n```\n%s\n```" % traceback_msg)
  128. def _activate_callback_env(self):
  129. """激活回调功能
  130. ThreadBasedWebIOSession的回调实现原理是:创建一个单独的线程用于接收回调事件,进而调用相关的回调函数。
  131. 当用户Task中并没有使用到回调功能时,不必开启此线程,可以节省资源
  132. """
  133. if self.callback_mq is not None: # 回调功能已经激活
  134. return
  135. self.callback_mq = queue.Queue(maxsize=self.callback_mq_maxsize)
  136. self.callback_thread = threading.Thread(target=self._dispatch_callback_event,
  137. daemon=True, name='callback-' + random_str(10))
  138. self.register_thread(self.callback_thread)
  139. self.callback_thread.start()
  140. logger.debug('Callback thread start')
  141. def _dispatch_callback_event(self):
  142. while not self.closed():
  143. event = self.callback_mq.get()
  144. if event is None: # 结束信号
  145. break
  146. callback_info = self.callbacks.get(event['coro_id'])
  147. if not callback_info:
  148. logger.error("No callback for coro_id:%s", event['coro_id'])
  149. return
  150. callback, mutex = callback_info
  151. def run(callback):
  152. try:
  153. callback(event['data'])
  154. except:
  155. ThreadBasedWebIOSession.get_current_session().on_task_exception()
  156. if mutex:
  157. run(callback)
  158. else:
  159. t = threading.Thread(target=run, kwargs=dict(callback=callback),
  160. daemon=True, name=event['coro_id'])
  161. self.register_thread(t)
  162. t.start()
  163. def register_callback(self, callback, serial_mode):
  164. """ 向Session注册一个回调函数,返回回调id
  165. Session需要保证当收到前端发送的事件消息 ``{event: "callback",coro_id: 回调id, data:...}`` 时,
  166. ``callback`` 回调函数被执行, 并传入事件消息中的 ``data`` 字段值作为参数
  167. :param bool serial_mode: 串行模式模式。若为 ``True`` ,则对于同一组件的点击事件,串行执行其回调函数
  168. """
  169. assert (not asyncio.iscoroutinefunction(callback)) and (not inspect.isgeneratorfunction(callback)), ValueError(
  170. "In ThreadBasedWebIOSession.register_callback, `callback` must be a simple function, "
  171. "not coroutine function or generator function. ")
  172. self._activate_callback_env()
  173. callback_id = 'CB-%s-%s' % (getattr(callback, '__name__', ''), random_str(10))
  174. self.callbacks[callback_id] = (callback, serial_mode)
  175. return callback_id
  176. def register_thread(self, t: threading.Thread, as_daemon=True):
  177. """注册线程,以便在线程内调用 pywebio 交互函数"""
  178. if as_daemon:
  179. t.setDaemon(True)
  180. tname = t.getName()
  181. self.threads.append(tname)
  182. self.thread2session[tname] = self
  183. event_mq = queue.Queue(maxsize=self.event_mq_maxsize)
  184. self.event_mqs[tname] = event_mq
  185. return event_mq