threadbased.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import asyncio
  2. import inspect
  3. import logging
  4. import queue
  5. import sys
  6. import threading
  7. import traceback
  8. from .base import AbstractSession
  9. from ..exceptions import SessionNotFoundException
  10. from ..utils import random_str
  11. logger = logging.getLogger(__name__)
  12. """
  13. 基于线程的会话实现
  14. 主任务线程退出后,连接关闭,但不会清理主任务线程产生的其他线程
  15. 客户端连接关闭后,后端线程不会退出,但是再次调用
  16. todo: thread 重名
  17. """
  18. # todo 线程安全
  19. class ThreadBasedWebIOSession(AbstractSession):
  20. thread2session = {} # thread_id -> session
  21. event_mq_maxsize = 100
  22. callback_mq_maxsize = 100
  23. @classmethod
  24. def get_current_session(cls) -> "ThreadBasedWebIOSession":
  25. curr = threading.current_thread().getName()
  26. session = cls.thread2session.get(curr)
  27. if session is None:
  28. raise SessionNotFoundException("Can't find current session. Maybe session closed.")
  29. return session
  30. @staticmethod
  31. def get_current_task_id():
  32. return threading.current_thread().getName()
  33. def __init__(self, target, on_task_message=None, on_session_close=None, loop=None):
  34. """
  35. :param target_func: 会话运行的函数
  36. :param on_coro_msg: 由协程内发给session的消息的处理函数
  37. :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
  38. 需要保证会话内的所有消息都传送到了客户端
  39. :param loop: 事件循环。若on_task_message或者on_session_close中有调用使用asyncio事件循环的调用,
  40. 则需要事件循环实例来将回调在事件循环的线程中执行
  41. """
  42. self._on_task_message = on_task_message or (lambda _: None)
  43. self._on_session_close = on_session_close or (lambda: None)
  44. self._loop = loop
  45. self._server_msg_lock = threading.Lock()
  46. self.threads = [] # 当前会话的线程id集合,用户会话结束后,清理数据
  47. self.unhandled_task_msgs = []
  48. self.event_mqs = {} # thread_id -> event msg queue
  49. self._closed = False
  50. # 用于实现回调函数的注册
  51. self.callback_mq = None
  52. self.callback_thread = None
  53. self.callbacks = {} # callback_id -> (callback_func, is_mutex)
  54. self._start_main_task(target)
  55. def _start_main_task(self, target):
  56. assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
  57. "In ThreadBasedWebIOSession.__init__, `target` must be a simple function, "
  58. "not coroutine function or generator function. ")
  59. def thread_task(target):
  60. try:
  61. target()
  62. except Exception as e:
  63. self.on_task_exception()
  64. finally:
  65. self.send_task_message(dict(command='close_session'))
  66. self.close()
  67. task_name = '%s-%s' % (target.__name__, random_str(10))
  68. thread = threading.Thread(target=thread_task, kwargs=dict(target=target),
  69. daemon=True, name=task_name)
  70. self.register_thread(thread)
  71. thread.start()
  72. def send_task_message(self, message):
  73. """向会话发送来自协程内的消息
  74. :param dict message: 消息
  75. """
  76. with self._server_msg_lock:
  77. self.unhandled_task_msgs.append(message)
  78. if self._loop:
  79. self._loop.call_soon_threadsafe(self._on_task_message, self)
  80. else:
  81. self._on_task_message(self)
  82. def next_client_event(self):
  83. name = threading.current_thread().getName()
  84. event_mq = self.get_current_session().event_mqs.get(name)
  85. return event_mq.get()
  86. def send_client_event(self, event):
  87. """向会话发送来自用户浏览器的事件️
  88. :param dict event: 事件️消息
  89. """
  90. task_id = event['coro_id']
  91. mq = self.event_mqs.get(task_id)
  92. if not mq and task_id in self.callbacks:
  93. mq = self.callback_mq
  94. if not mq:
  95. logger.error('event_mqs not found, task_id:%s', task_id)
  96. return
  97. mq.put(event)
  98. def get_task_messages(self):
  99. with self._server_msg_lock:
  100. msgs = self.unhandled_task_msgs
  101. self.unhandled_task_msgs = []
  102. return msgs
  103. def _cleanup(self):
  104. self.event_mqs = {}
  105. # Don't clean unhandled_task_msgs, it may not send to client
  106. # self.unhandled_task_msgs = []
  107. for t in self.threads:
  108. del ThreadBasedWebIOSession.thread2session[t]
  109. # pass
  110. if self.callback_mq is not None: # 回调功能已经激活
  111. self.callback_mq.put(None) # 结束回调线程
  112. def close(self, no_session_close_callback=False):
  113. """关闭当前Session
  114. :param bool no_session_close_callback: 不调用 on_session_close 会话结束的处理函数。
  115. 当 close 是由后端Backend调用时可能希望开启 no_session_close_callback
  116. """
  117. self._cleanup()
  118. self._closed = True
  119. if not no_session_close_callback:
  120. if self._loop:
  121. self._loop.call_soon_threadsafe(self._on_session_close)
  122. else:
  123. self._on_session_close()
  124. def closed(self):
  125. return self._closed
  126. def on_task_exception(self):
  127. from ..output import put_markdown # todo
  128. logger.exception('Error in coroutine executing')
  129. type, value, tb = sys.exc_info()
  130. tb_len = len(list(traceback.walk_tb(tb)))
  131. lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
  132. traceback_msg = ''.join(lines)
  133. put_markdown("发生错误:\n```\n%s\n```" % traceback_msg)
  134. def _activate_callback_env(self):
  135. """激活回调功能
  136. ThreadBasedWebIOSession的回调实现原理是:创建一个单独的线程用于接收回调事件,进而调用相关的回调函数。
  137. 当用户Task中并没有使用到回调功能时,不必开启此线程,可以节省资源
  138. """
  139. if self.callback_mq is not None: # 回调功能已经激活
  140. return
  141. self.callback_mq = queue.Queue(maxsize=self.callback_mq_maxsize)
  142. self.callback_thread = threading.Thread(target=self._dispatch_callback_event,
  143. daemon=True, name='callback-' + random_str(10))
  144. self.register_thread(self.callback_thread)
  145. self.callback_thread.start()
  146. logger.debug('Callback thread start')
  147. def _dispatch_callback_event(self):
  148. while not self.closed():
  149. event = self.callback_mq.get()
  150. if event is None: # 结束信号
  151. break
  152. callback_info = self.callbacks.get(event['coro_id'])
  153. if not callback_info:
  154. logger.error("No callback for coro_id:%s", event['coro_id'])
  155. return
  156. callback, mutex = callback_info
  157. def run(callback):
  158. try:
  159. callback(event['data'])
  160. except:
  161. ThreadBasedWebIOSession.get_current_session().on_task_exception()
  162. if mutex:
  163. run(callback)
  164. else:
  165. t = threading.Thread(target=run, kwargs=dict(callback=callback),
  166. daemon=True, name=event['coro_id'])
  167. self.register_thread(t)
  168. t.start()
  169. def register_callback(self, callback, serial_mode):
  170. """ 向Session注册一个回调函数,返回回调id
  171. Session需要保证当收到前端发送的事件消息 ``{event: "callback",coro_id: 回调id, data:...}`` 时,
  172. ``callback`` 回调函数被执行, 并传入事件消息中的 ``data`` 字段值作为参数
  173. :param bool serial_mode: 串行模式模式。若为 ``True`` ,则对于同一组件的点击事件,串行执行其回调函数
  174. """
  175. assert (not asyncio.iscoroutinefunction(callback)) and (not inspect.isgeneratorfunction(callback)), ValueError(
  176. "In ThreadBasedWebIOSession.register_callback, `callback` must be a simple function, "
  177. "not coroutine function or generator function. ")
  178. self._activate_callback_env()
  179. callback_id = 'CB-%s-%s' % (getattr(callback, '__name__', ''), random_str(10))
  180. self.callbacks[callback_id] = (callback, serial_mode)
  181. return callback_id
  182. def register_thread(self, t: threading.Thread, as_daemon=True):
  183. """注册线程,以便在线程内调用 pywebio 交互函数
  184. :param threading.Thread thread: 线程对象
  185. :param bool as_daemon: 是否将线程设置为 daemon 线程. 默认为 True
  186. """
  187. if as_daemon:
  188. t.setDaemon(True)
  189. tname = t.getName()
  190. self.threads.append(tname)
  191. self.thread2session[tname] = self
  192. event_mq = queue.Queue(maxsize=self.event_mq_maxsize)
  193. self.event_mqs[tname] = event_mq
  194. class DesignatedThreadSession(ThreadBasedWebIOSession):
  195. """以指定进程为会话"""
  196. def __init__(self, thread, on_task_message=None, loop=None):
  197. """
  198. :param on_coro_msg: 由协程内发给session的消息的处理函数
  199. :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
  200. 需要保证会话内的所有消息都传送到了客户端
  201. :param loop: 事件循环。若on_task_message或者on_session_close中有调用使用asyncio事件循环的调用,
  202. 则需要事件循环实例来将回调在事件循环的线程中执行
  203. """
  204. self._on_task_message = on_task_message or (lambda _: None)
  205. self._on_session_close = lambda: None
  206. self._loop = loop
  207. self._server_msg_lock = threading.Lock()
  208. self.threads = [] # 当前会话的线程id集合,用户会话结束后,清理数据
  209. self.unhandled_task_msgs = []
  210. self.event_mqs = {} # thread_id -> event msg queue
  211. self._closed = False
  212. # 用于实现回调函数的注册
  213. self.callback_mq = None
  214. self.callback_thread = None
  215. self.callbacks = {} # callback_id -> (callback_func, is_mutex)
  216. self.register_thread(thread, as_daemon=False)