|
@@ -52,12 +52,15 @@ class CoroutineBasedSession(AbstractSession):
|
|
raise RuntimeError("No current task found in context!")
|
|
raise RuntimeError("No current task found in context!")
|
|
return _context.current_task_id
|
|
return _context.current_task_id
|
|
|
|
|
|
- def __init__(self, coroutine_func, on_task_command=None, on_session_close=None):
|
|
|
|
|
|
+ def __init__(self, target, on_task_command=None, on_session_close=None):
|
|
"""
|
|
"""
|
|
- :param coroutine_func: 协程函数
|
|
|
|
|
|
+ :param target: 协程函数
|
|
:param on_task_command: 由协程内发给session的消息的处理函数
|
|
:param on_task_command: 由协程内发给session的消息的处理函数
|
|
:param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
|
|
:param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
|
|
"""
|
|
"""
|
|
|
|
+ assert asyncio.iscoroutinefunction(target) or inspect.isgeneratorfunction(target), ValueError(
|
|
|
|
+ "In CoroutineBasedSession accept coroutine function or generator function as task function")
|
|
|
|
+
|
|
self._on_task_command = on_task_command or (lambda _: None)
|
|
self._on_task_command = on_task_command or (lambda _: None)
|
|
self._on_session_close = on_session_close or (lambda: None)
|
|
self._on_session_close = on_session_close or (lambda: None)
|
|
self.unhandled_task_msgs = []
|
|
self.unhandled_task_msgs = []
|
|
@@ -67,7 +70,7 @@ class CoroutineBasedSession(AbstractSession):
|
|
self._closed = False
|
|
self._closed = False
|
|
self.inactive_coro_instances = [] # 待激活的协程实例列表
|
|
self.inactive_coro_instances = [] # 待激活的协程实例列表
|
|
|
|
|
|
- self.main_task = Task(coroutine_func(), session=self, on_coro_stop=self._on_main_task_finish)
|
|
|
|
|
|
+ self.main_task = Task(target(), session=self, on_coro_stop=self._on_main_task_finish)
|
|
self.coros[self.main_task.coro_id] = self.main_task
|
|
self.coros[self.main_task.coro_id] = self.main_task
|
|
|
|
|
|
self._step_task(self.main_task)
|
|
self._step_task(self.main_task)
|