Explorar o código

fix: error when use flask backend with coroutine based session

wangweimin %!s(int64=5) %!d(string=hai) anos
pai
achega
148ebd1e2e
Modificáronse 2 ficheiros con 31 adicións e 12 borrados
  1. 4 0
      pywebio/platform/flask.py
  2. 27 12
      pywebio/session/coroutinebased.py

+ 4 - 0
pywebio/platform/flask.py

@@ -207,6 +207,7 @@ def run_event_loop(debug=False):
        See also: https://docs.python.org/3/library/asyncio-dev.html#asyncio-debug-mode
     """
     global _event_loop
+    CoroutineBasedSession.event_loop_thread_id = threading.current_thread().ident
     _event_loop = asyncio.new_event_loop()
     _event_loop.set_debug(debug)
     asyncio.set_event_loop(_event_loop)
@@ -262,4 +263,7 @@ def start_server(target, port=8080, host='localhost',
     if not disable_asyncio and get_session_implement() is CoroutineBasedSession:
         threading.Thread(target=run_event_loop, daemon=True).start()
 
+    if not debug:
+        logging.getLogger('werkzeug').setLevel(logging.WARNING)
+
     app.run(host=host, port=port, debug=debug, **flask_options)

+ 27 - 12
pywebio/session/coroutinebased.py

@@ -4,7 +4,7 @@ import sys
 import threading
 import traceback
 from contextlib import contextmanager
-
+from functools import partial
 from .base import AbstractSession
 from ..exceptions import SessionNotFoundException, SessionClosedException, SessionException
 from ..utils import random_str, isgeneratorfunction, iscoroutinefunction, catch_exp_call
@@ -34,18 +34,24 @@ class CoroutineBasedSession(AbstractSession):
 
     当主协程任务和会话内所有通过 `run_async` 注册的协程都退出后,会话关闭。
     当用户浏览器主动关闭会话,CoroutineBasedSession.close 被调用, 协程任务和会话内所有通过 `run_async` 注册的协程都被关闭。
+
     """
 
+    # 运行事件循环的线程id
+    # 用于在 CoroutineBasedSession.get_current_session() 判断调用方是否合法
+    # Tornado backend时,在创建第一个CoroutineBasedSession时初始化
+    # Flask backend时,在platform.flaskrun_event_loop()时初始化
+    event_loop_thread_id = None
+
     _active_session_cnt = 0
 
     @classmethod
     def active_session_count(cls):
         return cls._active_session_cnt
 
-    @staticmethod
-    def get_current_session() -> "CoroutineBasedSession":
-        if _context.current_session is None or \
-                _context.current_session.session_thread_id != threading.current_thread().ident:
+    @classmethod
+    def get_current_session(cls) -> "CoroutineBasedSession":
+        if _context.current_session is None or cls.event_loop_thread_id != threading.current_thread().ident:
             raise SessionNotFoundException("No session found in current context!")
 
         if _context.current_session.closed():
@@ -79,8 +85,10 @@ class CoroutineBasedSession(AbstractSession):
         # 当前会话未被Backend处理的消息
         self.unhandled_task_msgs = []
 
-        # 创建会话的线程id。当前会话只能在本线程中使用
-        self.session_thread_id = threading.current_thread().ident
+        # 在创建第一个CoroutineBasedSession时 event_loop_thread_id 还未被初始化
+        # 则当前线程即为运行 event loop 的线程
+        if CoroutineBasedSession.event_loop_thread_id is None:
+            CoroutineBasedSession.event_loop_thread_id = threading.current_thread().ident
 
         # 会话内的协程任务
         self.coros = {}  # coro_task_id -> Task()
@@ -96,7 +104,7 @@ class CoroutineBasedSession(AbstractSession):
         self._step_task(main_task)
 
     def _step_task(self, task, result=None):
-        task.step(result)
+        asyncio.get_event_loop().call_soon_threadsafe(partial(task.step, result))
 
     def _on_task_finish(self, task: "Task"):
         self._alive_coro_cnt -= 1
@@ -140,7 +148,6 @@ class CoroutineBasedSession(AbstractSession):
         if not coro:
             logger.error('coro not found, coro_id:%s', coro_id)
             return
-
         self._step_task(coro, event)
 
     def get_task_commands(self):
@@ -193,7 +200,11 @@ class CoroutineBasedSession(AbstractSession):
 
         async def callback_coro():
             while True:
-                event = await self.next_client_event()
+                try:
+                    event = await self.next_client_event()
+                except SessionClosedException:
+                    return
+
                 assert event['event'] == 'callback'
                 coro = None
                 if iscoroutinefunction(callback):
@@ -204,7 +215,7 @@ class CoroutineBasedSession(AbstractSession):
                     try:
                         callback(event['data'])
                     except:
-                        CoroutineBasedSession.get_current_session().on_task_exception()
+                        self.on_task_exception()
 
                 if coro is not None:
                     if mutex_mode:
@@ -224,15 +235,19 @@ class CoroutineBasedSession(AbstractSession):
         :param coro_obj: 协程对象
         :return: An instance of  `TaskHandle` is returned, which can be used later to close the task.
         """
+        assert asyncio.iscoroutine(coro_obj), '`run_async()` only accept coroutine object'
+
         self._alive_coro_cnt += 1
 
         task = Task(coro_obj, session=self, on_coro_stop=self._on_task_finish)
         self.coros[task.coro_id] = task
-        asyncio.get_event_loop().call_soon(task.step)
+        asyncio.get_event_loop().call_soon_threadsafe(task.step)
         return task.task_handle()
 
     async def run_asyncio_coroutine(self, coro_obj):
         """若会话线程和运行事件的线程不是同一个线程,需要用 asyncio_coroutine 来运行asyncio中的协程"""
+        assert asyncio.iscoroutine(coro_obj), '`run_asyncio_coroutine()` only accept coroutine object'
+
         res = await WebIOFuture(coro=coro_obj)
         return res