coroutinebased.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
import asyncio
import logging
import threading
import types
from contextlib import contextmanager
from functools import partial

from .base import Session
from ..exceptions import SessionNotFoundException, SessionClosedException, SessionException
from ..utils import random_str, isgeneratorfunction, iscoroutinefunction
# Module-level logger, named after this module; the session and task classes in
# this file use it for debug/error/warning messages about task lifecycle.
logger = logging.getLogger(__name__)
  10. class WebIOFuture:
  11. def __init__(self, coro=None):
  12. self.coro = coro
  13. def __iter__(self):
  14. result = yield self
  15. return result
  16. __await__ = __iter__ # make compatible with 'await' expression
  17. class _context:
  18. current_session = None # type:"CoroutineBasedSession"
  19. current_task_id = None
  20. class CoroutineBasedSession(Session):
  21. """
  22. 基于协程的任务会话
  23. 当主协程任务和会话内所有通过 `run_async` 注册的协程都退出后,会话关闭。
  24. 当用户浏览器主动关闭会话,CoroutineBasedSession.close 被调用, 协程任务和会话内所有通过 `run_async` 注册的协程都被关闭。
  25. """
  26. # 运行事件循环的线程id
  27. # 用于在 CoroutineBasedSession.get_current_session() 判断调用方是否合法
  28. # Tornado backend时,在创建第一个CoroutineBasedSession时初始化
  29. # Flask backend时,在platform.flaskrun_event_loop()时初始化
  30. event_loop_thread_id = None
  31. @classmethod
  32. def get_current_session(cls) -> "CoroutineBasedSession":
  33. if _context.current_session is None or cls.event_loop_thread_id != threading.current_thread().ident:
  34. raise SessionNotFoundException("No session found in current context!")
  35. if _context.current_session.closed():
  36. raise SessionClosedException
  37. return _context.current_session
  38. @staticmethod
  39. def get_current_task_id():
  40. if _context.current_task_id is None:
  41. raise RuntimeError("No current task found in context!")
  42. return _context.current_task_id
  43. def __init__(self, target, session_info, on_task_command=None, on_session_close=None):
  44. """
  45. :param target: 协程函数
  46. :param on_task_command: 由协程内发给session的消息的处理函数
  47. :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
  48. """
  49. assert iscoroutinefunction(target) or isgeneratorfunction(target), ValueError(
  50. "CoroutineBasedSession accept coroutine function or generator function as task function")
  51. super().__init__(session_info)
  52. cls = type(self)
  53. self._on_task_command = on_task_command or (lambda _: None)
  54. self._on_session_close = on_session_close or (lambda: None)
  55. # 当前会话未被Backend处理的消息
  56. self.unhandled_task_msgs = []
  57. # 在创建第一个CoroutineBasedSession时 event_loop_thread_id 还未被初始化
  58. # 则当前线程即为运行 event loop 的线程
  59. if cls.event_loop_thread_id is None:
  60. cls.event_loop_thread_id = threading.current_thread().ident
  61. # 会话内的协程任务
  62. self.coros = {} # coro_task_id -> Task()
  63. self._closed = False
  64. self._need_keep_alive = False
  65. # 当前会话未结束运行(已创建和正在运行的)的协程数量。当 _alive_coro_cnt 变为 0 时,会话结束。
  66. self._alive_coro_cnt = 1
  67. main_task = Task(self._start_main_task(target), session=self, on_coro_stop=self._on_task_finish)
  68. self.coros[main_task.coro_id] = main_task
  69. self._step_task(main_task)
  70. async def _start_main_task(self, target):
  71. await target()
  72. if self.need_keep_alive():
  73. from ..session import hold
  74. await hold()
  75. def _step_task(self, task, result=None):
  76. asyncio.get_event_loop().call_soon_threadsafe(partial(task.step, result))
  77. def _on_task_finish(self, task: "Task"):
  78. self._alive_coro_cnt -= 1
  79. if task.coro_id in self.coros:
  80. logger.debug('del self.coros[%s]', task.coro_id)
  81. del self.coros[task.coro_id]
  82. if self._alive_coro_cnt <= 0 and not self.closed():
  83. self.send_task_command(dict(command='close_session'))
  84. self._on_session_close()
  85. self.close()
  86. def send_task_command(self, command):
  87. """向会话发送来自协程内的消息
  88. :param dict command: 消息
  89. """
  90. if self.closed():
  91. raise SessionClosedException()
  92. self.unhandled_task_msgs.append(command)
  93. self._on_task_command(self)
  94. async def next_client_event(self):
  95. # 函数开始不需要判断 self.closed()
  96. # 如果会话关闭,对 get_current_session().next_client_event() 的调用会抛出SessionClosedException
  97. return await WebIOFuture()
  98. def send_client_event(self, event):
  99. """向会话发送来自用户浏览器的事件️
  100. :param dict event: 事件️消息
  101. """
  102. coro_id = event['task_id']
  103. coro = self.coros.get(coro_id)
  104. if not coro:
  105. logger.error('coro not found, coro_id:%s', coro_id)
  106. return
  107. self._step_task(coro, event)
  108. def get_task_commands(self):
  109. msgs = self.unhandled_task_msgs
  110. self.unhandled_task_msgs = []
  111. return msgs
  112. def _cleanup(self):
  113. for t in list(self.coros.values()): # t.close() may cause self.coros changed size
  114. t.step(SessionClosedException, throw_exp=True)
  115. # in case that the task catch the SessionClosedException, we need to close it manually
  116. t.close()
  117. self.coros = {} # delete session tasks
  118. # reset the reference, to avoid circular reference
  119. self._on_session_close = None
  120. self._on_task_command = None
  121. def close(self, nonblock=False):
  122. """关闭当前Session。由Backend调用"""
  123. if self.closed():
  124. return
  125. super().close()
  126. self._cleanup()
  127. def register_callback(self, callback, mutex_mode=False):
  128. """ 向Session注册一个回调函数,返回回调id
  129. :type callback: Callable or Coroutine
  130. :param callback: 回调函数. 函数签名为 ``callback(data)``. ``data`` 参数为回调事件的值
  131. :param bool mutex_mode: 互斥模式。若为 ``True`` ,则在运行回调函数过程中,无法响应同一组件(callback_id相同)的新点击事件,仅当 ``callback`` 为协程函数时有效
  132. :return str: 回调id.
  133. """
  134. async def callback_coro():
  135. while True:
  136. try:
  137. event = await self.next_client_event()
  138. except SessionClosedException:
  139. return
  140. assert event['event'] == 'callback'
  141. coro = None
  142. if iscoroutinefunction(callback):
  143. coro = callback(event['data'])
  144. elif isgeneratorfunction(callback):
  145. coro = asyncio.coroutine(callback)(event['data'])
  146. else:
  147. try:
  148. res = callback(event['data'])
  149. if asyncio.iscoroutine(res):
  150. coro = res
  151. else:
  152. del res # `res` maybe pywebio.io_ctrl.Output, so need release `res`
  153. except Exception:
  154. self.on_task_exception()
  155. if coro is not None:
  156. if mutex_mode:
  157. await coro
  158. else:
  159. self.run_async(coro)
  160. cls = type(self)
  161. callback_task = Task(callback_coro(), cls.get_current_session())
  162. # Activate task
  163. # Don't callback.step(), it will result in recursive calls to step()
  164. # todo: integrate with inactive_coro_instances
  165. callback_task.coro.send(None)
  166. cls.get_current_session().coros[callback_task.coro_id] = callback_task
  167. self._need_keep_alive = True
  168. return callback_task.coro_id
  169. def run_async(self, coro_obj):
  170. """异步运行协程对象。可以在协程内调用 PyWebIO 交互函数
  171. :param coro_obj: 协程对象
  172. :return: An instance of `TaskHandler` is returned, which can be used later to close the task.
  173. """
  174. assert asyncio.iscoroutine(coro_obj), '`run_async()` only accept coroutine object'
  175. self._alive_coro_cnt += 1
  176. task = Task(coro_obj, session=self, on_coro_stop=self._on_task_finish)
  177. self.coros[task.coro_id] = task
  178. asyncio.get_event_loop().call_soon_threadsafe(task.step)
  179. return task.task_handle()
  180. async def run_asyncio_coroutine(self, coro_obj):
  181. """若会话线程和运行事件的线程不是同一个线程,需要用 asyncio_coroutine 来运行asyncio中的协程"""
  182. assert asyncio.iscoroutine(coro_obj), '`run_asyncio_coroutine()` only accept coroutine object'
  183. res = await WebIOFuture(coro=coro_obj)
  184. return res
  185. def need_keep_alive(self) -> bool:
  186. return self._need_keep_alive
  187. class TaskHandler:
  188. """The handler of coroutine task
  189. See also: `run_async() <pywebio.session.run_async>`
  190. """
  191. def __init__(self, close, closed):
  192. self._close = close
  193. self._closed = closed
  194. def close(self):
  195. """Close the coroutine task."""
  196. return self._close()
  197. def closed(self) -> bool:
  198. """Returns a bool stating whether the coroutine task is closed. """
  199. return self._closed()
  200. class Task:
  201. @contextmanager
  202. def session_context(self):
  203. """
  204. >>> with session_context():
  205. ... res = self.coros[-1].send(data)
  206. """
  207. # todo issue: with 语句可能发生嵌套,导致内层with退出时,将属性置空
  208. _context.current_session = self.session
  209. _context.current_task_id = self.coro_id
  210. try:
  211. yield
  212. finally:
  213. _context.current_session = None
  214. _context.current_task_id = None
  215. @staticmethod
  216. def gen_coro_id(coro=None):
  217. """生成协程id"""
  218. name = 'coro'
  219. if hasattr(coro, '__name__'):
  220. name = coro.__name__
  221. return '%s-%s' % (name, random_str(10))
  222. def __init__(self, coro, session: CoroutineBasedSession, on_coro_stop=None):
  223. """
  224. :param coro: 协程对象
  225. :param session: 创建该Task的会话实例
  226. :param on_coro_stop: 任务结束(正常结束或外部调用Task.close)时运行的回调
  227. """
  228. self.session = session
  229. self.coro = coro
  230. self.result = None
  231. self.task_closed = False # 任务完毕/取消
  232. self.on_coro_stop = on_coro_stop or (lambda _: None)
  233. self.coro_id = self.gen_coro_id(self.coro)
  234. self.pending_futures = {} # id(future) -> future
  235. logger.debug('Task[%s] created ', self.coro_id)
  236. def step(self, result=None, throw_exp=False):
  237. """激活协程
  238. :param any result: 向协程传入的数据
  239. :param bool throw_exp: 是否向协程引发异常,为 True 时, result 参数为相应的异常对象
  240. """
  241. coro_yield = None
  242. with self.session_context():
  243. try:
  244. if throw_exp:
  245. coro_yield = self.coro.throw(result)
  246. else:
  247. coro_yield = self.coro.send(result)
  248. except StopIteration as e:
  249. if len(e.args) == 1:
  250. self.result = e.args[0]
  251. self.close()
  252. logger.debug('Task[%s] finished', self.coro_id)
  253. except Exception as e:
  254. if not isinstance(e, SessionException):
  255. self.session.on_task_exception()
  256. self.close()
  257. if coro_yield is None:
  258. return
  259. future = None
  260. if isinstance(coro_yield, WebIOFuture):
  261. if coro_yield.coro:
  262. future = asyncio.run_coroutine_threadsafe(coro_yield.coro, asyncio.get_event_loop())
  263. else:
  264. future = coro_yield
  265. if not self.session.closed() and hasattr(future, 'add_done_callback'):
  266. future.add_done_callback(self._wakeup)
  267. self.pending_futures[id(future)] = future
  268. def _wakeup(self, future):
  269. if not future.cancelled():
  270. del self.pending_futures[id(future)]
  271. self.step(future.result())
  272. def close(self):
  273. if self.task_closed:
  274. return
  275. self.task_closed = True
  276. self.coro.close()
  277. while self.pending_futures:
  278. _, f = self.pending_futures.popitem()
  279. f.cancel()
  280. self.on_coro_stop(self)
  281. self.on_coro_stop = None # avoid circular reference
  282. self.session = None
  283. logger.debug('Task[%s] closed', self.coro_id)
  284. def __del__(self):
  285. if not self.task_closed:
  286. logger.warning('Task[%s] was destroyed but it is pending!', self.coro_id)
  287. def task_handle(self):
  288. handle = TaskHandler(close=self.close, closed=lambda: self.task_closed)
  289. return handle