1
0

coroutinebased.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. import asyncio
  2. import inspect
  3. import logging
  4. import sys
  5. import traceback
  6. from contextlib import contextmanager
  7. from .base import AbstractSession
  8. from ..exceptions import SessionNotFoundException
  9. from ..utils import random_str
# Module-level logger: records task lifecycle events (created / finished / closed),
# unknown task ids in client events, and exceptions raised inside task coroutines.
logger = logging.getLogger(__name__)
  11. class WebIOFuture:
  12. def __init__(self, coro=None):
  13. self.coro = coro
  14. def __iter__(self):
  15. result = yield self
  16. return result
  17. __await__ = __iter__ # make compatible with 'await' expression
  18. class _context:
  19. current_session = None # type:"AsyncBasedSession"
  20. current_task_id = None
  21. class CoroutineBasedSession(AbstractSession):
  22. """
  23. 基于协程的任务会话
  24. 当主协程任务和会话内所有通过 `run_async` 注册的协程都退出后,会话关闭。
  25. 当用户浏览器主动关闭会话,CoroutineBasedSession.close 被调用, 协程任务和会话内所有通过 `run_async` 注册的协程都被关闭。
  26. """
  27. @staticmethod
  28. def get_current_session() -> "CoroutineBasedSession":
  29. if _context.current_session is None:
  30. raise SessionNotFoundException("No current found in context!")
  31. return _context.current_session
  32. @staticmethod
  33. def get_current_task_id():
  34. if _context.current_task_id is None:
  35. raise RuntimeError("No current task found in context!")
  36. return _context.current_task_id
  37. def __init__(self, target, on_task_command=None, on_session_close=None):
  38. """
  39. :param target: 协程函数
  40. :param on_task_command: 由协程内发给session的消息的处理函数
  41. :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
  42. """
  43. assert asyncio.iscoroutinefunction(target) or inspect.isgeneratorfunction(target), ValueError(
  44. "CoroutineBasedSession accept coroutine function or generator function as task function")
  45. self._on_task_command = on_task_command or (lambda _: None)
  46. self._on_session_close = on_session_close or (lambda: None)
  47. self.unhandled_task_msgs = []
  48. self.coros = {} # coro_task_id -> coro
  49. self._closed = False
  50. self.inactive_coro_instances = [] # 待激活的协程实例列表
  51. self._not_closed_coro_cnt = 1 # 当前会话未结束运行的协程数量。当 self._not_closed_coro_cnt == 0 时,会话结束。
  52. main_task = Task(target(), session=self, on_coro_stop=self._on_task_finish)
  53. self.coros[main_task.coro_id] = main_task
  54. self._step_task(main_task)
  55. def _step_task(self, task, result=None):
  56. task.step(result)
  57. if task.task_closed and task.coro_id in self.coros:
  58. # 若task 为main task,则 task.step(result) 结束后,可能task已经结束,self.coros已被清理
  59. logger.debug('del self.coros[%s]', task.coro_id)
  60. del self.coros[task.coro_id]
  61. while self.inactive_coro_instances:
  62. coro = self.inactive_coro_instances.pop()
  63. sub_task = Task(coro, session=self, on_coro_stop=self._on_task_finish)
  64. self.coros[sub_task.coro_id] = sub_task
  65. sub_task.step()
  66. if sub_task.task_closed:
  67. logger.debug('del self.coros[%s]', sub_task.coro_id)
  68. del self.coros[sub_task.coro_id]
  69. def _on_task_finish(self):
  70. self._not_closed_coro_cnt -= 1
  71. if self._not_closed_coro_cnt <= 0:
  72. self.send_task_command(dict(command='close_session'))
  73. self._on_session_close()
  74. self.close()
  75. def send_task_command(self, command):
  76. """向会话发送来自协程内的消息
  77. :param dict command: 消息
  78. """
  79. self.unhandled_task_msgs.append(command)
  80. self._on_task_command(self)
  81. async def next_client_event(self):
  82. res = await WebIOFuture()
  83. return res
  84. def send_client_event(self, event):
  85. """向会话发送来自用户浏览器的事件️
  86. :param dict event: 事件️消息
  87. """
  88. coro_id = event['task_id']
  89. coro = self.coros.get(coro_id)
  90. if not coro:
  91. logger.error('coro not found, coro_id:%s', coro_id)
  92. return
  93. self._step_task(coro, event)
  94. def get_task_commands(self):
  95. msgs = self.unhandled_task_msgs
  96. self.unhandled_task_msgs = []
  97. return msgs
  98. def _cleanup(self):
  99. for t in self.coros.values():
  100. t.close()
  101. self.coros = {} # delete session tasks
  102. while self.inactive_coro_instances:
  103. coro = self.inactive_coro_instances.pop()
  104. coro.close()
  105. def close(self):
  106. """关闭当前Session。由Backend调用"""
  107. self._cleanup()
  108. self._closed = True
  109. # todo clean
  110. def closed(self):
  111. return self._closed
  112. def on_task_exception(self):
  113. from ..output import put_markdown # todo
  114. logger.exception('Error in coroutine executing')
  115. type, value, tb = sys.exc_info()
  116. tb_len = len(list(traceback.walk_tb(tb)))
  117. lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
  118. traceback_msg = ''.join(lines)
  119. put_markdown("发生错误:\n```\n%s\n```" % traceback_msg)
  120. def register_callback(self, callback, mutex_mode=False):
  121. """ 向Session注册一个回调函数,返回回调id
  122. :type callback: Callable or Coroutine
  123. :param callback: 回调函数. 可以是普通函数或者协程函数. 函数签名为 ``callback(data)``.
  124. :param bool mutex_mode: 互斥模式。若为 ``True`` ,则在运行回调函数过程中,无法响应同一组件的新点击事件,仅当 ``callback`` 为协程函数时有效
  125. :return str: 回调id.
  126. CoroutineBasedSession 保证当收到前端发送的事件消息 ``{event: "callback",coro_id: 回调id, data:...}`` 时,
  127. ``callback`` 回调函数被执行, 并传入事件消息中的 ``data`` 字段值作为参数
  128. """
  129. async def callback_coro():
  130. while True:
  131. event = await self.next_client_event()
  132. assert event['event'] == 'callback'
  133. coro = None
  134. if asyncio.iscoroutinefunction(callback):
  135. coro = callback(event['data'])
  136. elif inspect.isgeneratorfunction(callback):
  137. coro = asyncio.coroutine(callback)(event['data'])
  138. else:
  139. try:
  140. callback(event['data'])
  141. except:
  142. CoroutineBasedSession.get_current_session().on_task_exception()
  143. if coro is not None:
  144. if mutex_mode:
  145. await coro
  146. else:
  147. self.run_async(coro)
  148. callback_task = Task(callback_coro(), CoroutineBasedSession.get_current_session())
  149. callback_task.coro.send(None) # 激活,Non't callback.step() ,导致嵌套调用step todo 与inactive_coro_instances整合
  150. CoroutineBasedSession.get_current_session().coros[callback_task.coro_id] = callback_task
  151. return callback_task.coro_id
  152. def run_async(self, coro_obj):
  153. """异步运行协程对象。可以在协程内调用 PyWebIO 交互函数
  154. :param coro_obj: 协程对象
  155. """
  156. self.inactive_coro_instances.append(coro_obj)
  157. self._not_closed_coro_cnt += 1
  158. async def run_asyncio_coroutine(self, coro_obj):
  159. """若会话线程和运行事件的线程不是同一个线程,需要用 asyncio_coroutine 来运行asyncio中的协程"""
  160. res = await WebIOFuture(coro=coro_obj)
  161. return res
  162. class Task:
  163. @contextmanager
  164. def session_context(self):
  165. """
  166. >>> with session_context():
  167. ... res = self.coros[-1].send(data)
  168. """
  169. # todo issue: with 语句可能发生嵌套,导致内层with退出时,将属性置空
  170. _context.current_session = self.session
  171. _context.current_task_id = self.coro_id
  172. try:
  173. yield
  174. finally:
  175. _context.current_session = None
  176. _context.current_task_id = None
  177. @staticmethod
  178. def gen_coro_id(coro=None):
  179. name = 'coro'
  180. if hasattr(coro, '__name__'):
  181. name = coro.__name__
  182. return '%s-%s' % (name, random_str(10))
  183. def __init__(self, coro, session: CoroutineBasedSession, on_coro_stop=None):
  184. self.session = session
  185. self.coro = coro
  186. self.coro_id = None
  187. self.result = None
  188. self.task_closed = False # 任务完毕/取消
  189. self.on_coro_stop = on_coro_stop or (lambda: None)
  190. self.coro_id = self.gen_coro_id(self.coro)
  191. self.pending_futures = {} # id(future) -> future
  192. logger.debug('Task[%s] created ', self.coro_id)
  193. def step(self, result=None):
  194. coro_yield = None
  195. with self.session_context():
  196. try:
  197. coro_yield = self.coro.send(result)
  198. except StopIteration as e:
  199. if len(e.args) == 1:
  200. self.result = e.args[0]
  201. self.task_closed = True
  202. logger.debug('Task[%s] finished', self.coro_id)
  203. self.on_coro_stop()
  204. except Exception as e:
  205. self.session.on_task_exception()
  206. self.task_closed = True
  207. self.on_coro_stop()
  208. future = None
  209. if isinstance(coro_yield, WebIOFuture):
  210. if coro_yield.coro:
  211. future = asyncio.run_coroutine_threadsafe(coro_yield.coro, asyncio.get_event_loop())
  212. elif coro_yield is not None:
  213. future = coro_yield
  214. if not self.session.closed() and hasattr(future, 'add_done_callback'):
  215. future.add_done_callback(self._tornado_future_callback)
  216. self.pending_futures[id(future)] = future
  217. def _tornado_future_callback(self, future):
  218. if not future.cancelled():
  219. del self.pending_futures[id(future)]
  220. self.step(future.result())
  221. def close(self):
  222. logger.debug('Task[%s] closed', self.coro_id)
  223. self.coro.close()
  224. while self.pending_futures:
  225. _, f = self.pending_futures.popitem()
  226. f.cancel()
  227. self.task_closed = True
  228. def __del__(self):
  229. if not self.task_closed:
  230. logger.warning('Task[%s] not finished when destroy', self.coro_id)