|
@@ -4,7 +4,7 @@ import sys
|
|
import threading
|
|
import threading
|
|
import traceback
|
|
import traceback
|
|
from contextlib import contextmanager
|
|
from contextlib import contextmanager
|
|
-
|
|
|
|
|
|
+from functools import partial
|
|
from .base import AbstractSession
|
|
from .base import AbstractSession
|
|
from ..exceptions import SessionNotFoundException, SessionClosedException, SessionException
|
|
from ..exceptions import SessionNotFoundException, SessionClosedException, SessionException
|
|
from ..utils import random_str, isgeneratorfunction, iscoroutinefunction, catch_exp_call
|
|
from ..utils import random_str, isgeneratorfunction, iscoroutinefunction, catch_exp_call
|
|
@@ -34,18 +34,24 @@ class CoroutineBasedSession(AbstractSession):
|
|
|
|
|
|
当主协程任务和会话内所有通过 `run_async` 注册的协程都退出后,会话关闭。
|
|
当主协程任务和会话内所有通过 `run_async` 注册的协程都退出后,会话关闭。
|
|
当用户浏览器主动关闭会话,CoroutineBasedSession.close 被调用, 协程任务和会话内所有通过 `run_async` 注册的协程都被关闭。
|
|
当用户浏览器主动关闭会话,CoroutineBasedSession.close 被调用, 协程任务和会话内所有通过 `run_async` 注册的协程都被关闭。
|
|
|
|
+
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
+ # 运行事件循环的线程id
|
|
|
|
+ # 用于在 CoroutineBasedSession.get_current_session() 判断调用方是否合法
|
|
|
|
+ # Tornado backend时,在创建第一个CoroutineBasedSession时初始化
|
|
|
|
+ # Flask backend时,在platform.flaskrun_event_loop()时初始化
|
|
|
|
+ event_loop_thread_id = None
|
|
|
|
+
|
|
_active_session_cnt = 0
|
|
_active_session_cnt = 0
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def active_session_count(cls):
|
|
def active_session_count(cls):
|
|
return cls._active_session_cnt
|
|
return cls._active_session_cnt
|
|
|
|
|
|
- @staticmethod
|
|
|
|
- def get_current_session() -> "CoroutineBasedSession":
|
|
|
|
- if _context.current_session is None or \
|
|
|
|
- _context.current_session.session_thread_id != threading.current_thread().ident:
|
|
|
|
|
|
+ @classmethod
|
|
|
|
+ def get_current_session(cls) -> "CoroutineBasedSession":
|
|
|
|
+ if _context.current_session is None or cls.event_loop_thread_id != threading.current_thread().ident:
|
|
raise SessionNotFoundException("No session found in current context!")
|
|
raise SessionNotFoundException("No session found in current context!")
|
|
|
|
|
|
if _context.current_session.closed():
|
|
if _context.current_session.closed():
|
|
@@ -79,8 +85,10 @@ class CoroutineBasedSession(AbstractSession):
|
|
# 当前会话未被Backend处理的消息
|
|
# 当前会话未被Backend处理的消息
|
|
self.unhandled_task_msgs = []
|
|
self.unhandled_task_msgs = []
|
|
|
|
|
|
- # 创建会话的线程id。当前会话只能在本线程中使用
|
|
|
|
- self.session_thread_id = threading.current_thread().ident
|
|
|
|
|
|
+ # 在创建第一个CoroutineBasedSession时 event_loop_thread_id 还未被初始化
|
|
|
|
+ # 则当前线程即为运行 event loop 的线程
|
|
|
|
+ if CoroutineBasedSession.event_loop_thread_id is None:
|
|
|
|
+ CoroutineBasedSession.event_loop_thread_id = threading.current_thread().ident
|
|
|
|
|
|
# 会话内的协程任务
|
|
# 会话内的协程任务
|
|
self.coros = {} # coro_task_id -> Task()
|
|
self.coros = {} # coro_task_id -> Task()
|
|
@@ -96,7 +104,7 @@ class CoroutineBasedSession(AbstractSession):
|
|
self._step_task(main_task)
|
|
self._step_task(main_task)
|
|
|
|
|
|
def _step_task(self, task, result=None):
|
|
def _step_task(self, task, result=None):
|
|
- task.step(result)
|
|
|
|
|
|
+ asyncio.get_event_loop().call_soon_threadsafe(partial(task.step, result))
|
|
|
|
|
|
def _on_task_finish(self, task: "Task"):
|
|
def _on_task_finish(self, task: "Task"):
|
|
self._alive_coro_cnt -= 1
|
|
self._alive_coro_cnt -= 1
|
|
@@ -140,7 +148,6 @@ class CoroutineBasedSession(AbstractSession):
|
|
if not coro:
|
|
if not coro:
|
|
logger.error('coro not found, coro_id:%s', coro_id)
|
|
logger.error('coro not found, coro_id:%s', coro_id)
|
|
return
|
|
return
|
|
-
|
|
|
|
self._step_task(coro, event)
|
|
self._step_task(coro, event)
|
|
|
|
|
|
def get_task_commands(self):
|
|
def get_task_commands(self):
|
|
@@ -193,7 +200,11 @@ class CoroutineBasedSession(AbstractSession):
|
|
|
|
|
|
async def callback_coro():
|
|
async def callback_coro():
|
|
while True:
|
|
while True:
|
|
- event = await self.next_client_event()
|
|
|
|
|
|
+ try:
|
|
|
|
+ event = await self.next_client_event()
|
|
|
|
+ except SessionClosedException:
|
|
|
|
+ return
|
|
|
|
+
|
|
assert event['event'] == 'callback'
|
|
assert event['event'] == 'callback'
|
|
coro = None
|
|
coro = None
|
|
if iscoroutinefunction(callback):
|
|
if iscoroutinefunction(callback):
|
|
@@ -204,7 +215,7 @@ class CoroutineBasedSession(AbstractSession):
|
|
try:
|
|
try:
|
|
callback(event['data'])
|
|
callback(event['data'])
|
|
except:
|
|
except:
|
|
- CoroutineBasedSession.get_current_session().on_task_exception()
|
|
|
|
|
|
+ self.on_task_exception()
|
|
|
|
|
|
if coro is not None:
|
|
if coro is not None:
|
|
if mutex_mode:
|
|
if mutex_mode:
|
|
@@ -224,15 +235,19 @@ class CoroutineBasedSession(AbstractSession):
|
|
:param coro_obj: 协程对象
|
|
:param coro_obj: 协程对象
|
|
:return: An instance of `TaskHandle` is returned, which can be used later to close the task.
|
|
:return: An instance of `TaskHandle` is returned, which can be used later to close the task.
|
|
"""
|
|
"""
|
|
|
|
+ assert asyncio.iscoroutine(coro_obj), '`run_async()` only accept coroutine object'
|
|
|
|
+
|
|
self._alive_coro_cnt += 1
|
|
self._alive_coro_cnt += 1
|
|
|
|
|
|
task = Task(coro_obj, session=self, on_coro_stop=self._on_task_finish)
|
|
task = Task(coro_obj, session=self, on_coro_stop=self._on_task_finish)
|
|
self.coros[task.coro_id] = task
|
|
self.coros[task.coro_id] = task
|
|
- asyncio.get_event_loop().call_soon(task.step)
|
|
|
|
|
|
+ asyncio.get_event_loop().call_soon_threadsafe(task.step)
|
|
return task.task_handle()
|
|
return task.task_handle()
|
|
|
|
|
|
async def run_asyncio_coroutine(self, coro_obj):
|
|
async def run_asyncio_coroutine(self, coro_obj):
|
|
"""若会话线程和运行事件的线程不是同一个线程,需要用 asyncio_coroutine 来运行asyncio中的协程"""
|
|
"""若会话线程和运行事件的线程不是同一个线程,需要用 asyncio_coroutine 来运行asyncio中的协程"""
|
|
|
|
+ assert asyncio.iscoroutine(coro_obj), '`run_asyncio_coroutine()` only accept coroutine object'
|
|
|
|
+
|
|
res = await WebIOFuture(coro=coro_obj)
|
|
res = await WebIOFuture(coro=coro_obj)
|
|
return res
|
|
return res
|
|
|
|
|