asyncbased.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
import asyncio
import inspect
import logging
import sys
import traceback
import types
from contextlib import contextmanager

from .base import AbstractSession
from ..utils import random_str
  9. logger = logging.getLogger(__name__)
  10. class WebIOFuture:
  11. def __init__(self, coro=None):
  12. self.coro = coro
  13. def __iter__(self):
  14. result = yield self
  15. return result
  16. __await__ = __iter__ # make compatible with 'await' expression
  17. class _context:
  18. current_session = None # type:"AsyncBasedSession"
  19. current_task_id = None
  20. class AsyncBasedSession(AbstractSession):
  21. """
  22. 一个PyWebIO任务会话, 由不同的后端Backend创建并维护
  23. WebIOSession是不同的后端Backend与协程交互的桥梁:
  24. 后端Backend在接收到用户浏览器的数据后,会通过调用 ``send_client_msg`` 来通知会话,进而由WebIOSession驱动协程的运行。
  25. 协程内在调用输入输出函数后,会调用 ``send_coro_msg`` 向会话发送输入输出消息指令, WebIOSession将其保存并留给后端Backend处理。
  26. .. note::
  27. 后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
  28. """
  29. @staticmethod
  30. def get_current_session() -> "AsyncBasedSession":
  31. if _context.current_session is None:
  32. raise RuntimeError("No current found in context!")
  33. return _context.current_session
  34. @staticmethod
  35. def get_current_task_id():
  36. if _context.current_task_id is None:
  37. raise RuntimeError("No current task found in context!")
  38. return _context.current_task_id
  39. def __init__(self, coroutine_func, on_task_message=None, on_session_close=None):
  40. """
  41. :param coro_func: 协程函数
  42. :param on_coro_msg: 由协程内发给session的消息的处理函数
  43. :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
  44. """
  45. self._on_task_message = on_task_message or (lambda _: None)
  46. self._on_session_close = on_session_close or (lambda: None)
  47. self.unhandled_task_msgs = []
  48. self.coros = {} # coro_id -> coro
  49. self._closed = False
  50. self.inactive_coro_instances = [] # 待激活的协程实例列表
  51. self.main_task = Task(coroutine_func(), session=self, on_coro_stop=self._on_main_task_finish)
  52. self.coros[self.main_task.coro_id] = self.main_task
  53. self._step_task(self.main_task)
  54. def _step_task(self, task, result=None):
  55. task.step(result)
  56. if task.task_finished and task.coro_id in self.coros:
  57. # 若task 为main task,则 task.step(result) 结束后,可能task已经结束,self.coros已被清理
  58. logger.debug('del self.coros[%s]', task.coro_id)
  59. del self.coros[task.coro_id]
  60. while self.inactive_coro_instances and not self.main_task.task_finished:
  61. coro = self.inactive_coro_instances.pop()
  62. sub_task = Task(coro, session=self)
  63. self.coros[sub_task.coro_id] = sub_task
  64. sub_task.step()
  65. if sub_task.task_finished:
  66. logger.debug('del self.coros[%s]', sub_task.coro_id)
  67. del self.coros[sub_task.coro_id]
  68. def _on_main_task_finish(self):
  69. self.send_task_message(dict(command='close_session'))
  70. self.close()
  71. def send_task_message(self, message):
  72. """向会话发送来自协程内的消息
  73. :param dict message: 消息
  74. """
  75. self.unhandled_task_msgs.append(message)
  76. self._on_task_message(self)
  77. async def next_client_event(self):
  78. res = await WebIOFuture()
  79. return res
  80. def send_client_event(self, event):
  81. """向会话发送来自用户浏览器的事件️
  82. :param dict event: 事件️消息
  83. """
  84. coro_id = event['coro_id']
  85. coro = self.coros.get(coro_id)
  86. if not coro:
  87. logger.error('coro not found, coro_id:%s', coro_id)
  88. return
  89. self._step_task(coro, event)
  90. def get_task_messages(self):
  91. msgs = self.unhandled_task_msgs
  92. self.unhandled_task_msgs = []
  93. return msgs
  94. def _cleanup(self):
  95. for t in self.coros.values():
  96. t.close()
  97. self.coros = {} # delete session tasks
  98. while self.inactive_coro_instances:
  99. coro = self.inactive_coro_instances.pop()
  100. coro.close()
  101. def close(self, no_session_close_callback=False):
  102. """关闭当前Session
  103. :param bool no_session_close_callback: 不调用 on_session_close 会话结束的处理函数。
  104. 当 close 是由后端Backend调用时可能希望开启 no_session_close_callback
  105. """
  106. self._cleanup()
  107. self._closed = True
  108. if not no_session_close_callback:
  109. self._on_session_close()
  110. # todo clean
  111. def closed(self):
  112. return self._closed
  113. def on_task_exception(self):
  114. from ..output import put_markdown # todo
  115. logger.exception('Error in coroutine executing')
  116. type, value, tb = sys.exc_info()
  117. tb_len = len(list(traceback.walk_tb(tb)))
  118. lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
  119. traceback_msg = ''.join(lines)
  120. put_markdown("发生错误:\n```\n%s\n```" % traceback_msg)
  121. def register_callback(self, callback, mutex_mode):
  122. """ 向Session注册一个回调函数,返回回调id
  123. :type callback: Callable or Coroutine
  124. :param callback: 回调函数. 可以是普通函数或者协程函数. 函数签名为 ``callback(data)``.
  125. :return str: 回调id.
  126. AsyncBasedSession保证当收到前端发送的事件消息 ``{event: "callback",coro_id: 回调id, data:...}`` 时,
  127. ``callback`` 回调函数被执行, 并传入事件消息中的 ``data`` 字段值作为参数
  128. """
  129. async def callback_coro():
  130. while True:
  131. event = await self.next_client_event()
  132. assert event['event'] == 'callback'
  133. coro = None
  134. if asyncio.iscoroutinefunction(callback):
  135. coro = callback(event['data'])
  136. elif inspect.isgeneratorfunction(callback):
  137. coro = asyncio.coroutine(callback)(event['data'])
  138. else:
  139. try:
  140. callback(event['data'])
  141. except:
  142. AsyncBasedSession.get_current_session().on_task_exception()
  143. if coro is not None:
  144. if mutex_mode:
  145. await coro
  146. else:
  147. self.run_async(coro)
  148. callback_task = Task(callback_coro(), AsyncBasedSession.get_current_session())
  149. callback_task.coro.send(None) # 激活,Non't callback.step() ,导致嵌套调用step todo 与inactive_coro_instances整合
  150. AsyncBasedSession.get_current_session().coros[callback_task.coro_id] = callback_task
  151. return callback_task.coro_id
  152. def run_async(self, coro_obj):
  153. self.inactive_coro_instances.append(coro_obj)
  154. async def asyncio_coroutine(self, coro):
  155. """若会话线程和运行事件的线程不是同一个线程,需要用 asyncio_coroutine 来运行asyncio中的协程"""
  156. res = await WebIOFuture(coro=coro)
  157. return res
  158. class Task:
  159. @contextmanager
  160. def session_context(self):
  161. """
  162. >>> with session_context():
  163. ... res = self.coros[-1].send(data)
  164. """
  165. # todo issue: with 语句可能发生嵌套,导致内层with退出时,将属性置空
  166. _context.current_session = self.session
  167. _context.current_task_id = self.coro_id
  168. try:
  169. yield
  170. finally:
  171. _context.current_session = None
  172. _context.current_task_id = None
  173. @staticmethod
  174. def gen_coro_id(coro=None):
  175. name = 'coro'
  176. if hasattr(coro, '__name__'):
  177. name = coro.__name__
  178. return '%s-%s' % (name, random_str(10))
  179. def __init__(self, coro, session: AsyncBasedSession, on_coro_stop=None):
  180. self.session = session
  181. self.coro = coro
  182. self.coro_id = None
  183. self.result = None
  184. self.task_finished = False # 任务完毕/取消
  185. self.on_coro_stop = on_coro_stop or (lambda: None)
  186. self.coro_id = self.gen_coro_id(self.coro)
  187. self.pending_futures = {} # id(future) -> future
  188. logger.debug('Task[%s] created ', self.coro_id)
  189. def step(self, result=None):
  190. coro_yield = None
  191. with self.session_context():
  192. try:
  193. coro_yield = self.coro.send(result)
  194. except StopIteration as e:
  195. if len(e.args) == 1:
  196. self.result = e.args[0]
  197. self.task_finished = True
  198. logger.debug('Task[%s] finished', self.coro_id)
  199. self.on_coro_stop()
  200. except Exception as e:
  201. self.session.on_task_exception()
  202. future = None
  203. if isinstance(coro_yield, WebIOFuture):
  204. if coro_yield.coro:
  205. future = asyncio.run_coroutine_threadsafe(coro_yield.coro, asyncio.get_event_loop())
  206. elif coro_yield is not None:
  207. future = coro_yield
  208. if not self.session.closed() and hasattr(future, 'add_done_callback'):
  209. future.add_done_callback(self._tornado_future_callback)
  210. self.pending_futures[id(future)] = future
  211. def _tornado_future_callback(self, future):
  212. if not future.cancelled():
  213. del self.pending_futures[id(future)]
  214. self.step(future.result())
  215. def close(self):
  216. logger.debug('Task[%s] closed', self.coro_id)
  217. self.coro.close()
  218. while self.pending_futures:
  219. _, f = self.pending_futures.popitem()
  220. f.cancel()
  221. self.task_finished = True
  222. def __del__(self):
  223. if not self.task_finished:
  224. logger.warning('Task[%s] not finished when destroy', self.coro_id)