Pārlūkot izejas kodu

refine asyncio based session

wangweimin 5 gadi atpakaļ
vecāks
revīzija
0d03820d52

+ 7 - 8
pywebio/io_ctrl.py

@@ -2,20 +2,19 @@ import asyncio
 import inspect
 import logging
 
-from .framework import Global, Task
-from .framework import WebIOFuture
+from .session.asyncbased import WebIOFuture, AsyncBasedSession, Task
 from .ioloop import run_async
 
 logger = logging.getLogger(__name__)
 
 
 def send_msg(cmd, spec=None):
-    msg = dict(command=cmd, spec=spec, coro_id=Global.active_coro_id)
-    Global.active_ws.send_coro_msg(msg)
+    msg = dict(command=cmd, spec=spec, coro_id=AsyncBasedSession.get_current_task_id())
+    AsyncBasedSession.get_current_session().send_task_message(msg)
 
 
 async def next_event():
-    res = await WebIOFuture()
+    res = await AsyncBasedSession.get_current_session().next_client_event()
     return res
 
 
@@ -145,7 +144,7 @@ def output_register_callback(callback, save, mutex_mode):
                 try:
                     callback(event['data'], save)
                 except:
-                    Global.active_ws.on_coro_error()
+                    AsyncBasedSession.get_current_session().on_task_exception()
 
             if coro is not None:
                 if mutex_mode:
@@ -153,8 +152,8 @@ def output_register_callback(callback, save, mutex_mode):
                 else:
                     run_async(coro)
 
-    callback_task = Task(callback_coro(), Global.active_ws)
+    callback_task = Task(callback_coro(), AsyncBasedSession.get_current_session())
     callback_task.coro.send(None)  # 激活,Non't callback.step() ,导致嵌套调用step  todo 与inactive_coro_instances整合
-    Global.active_ws.coros[callback_task.coro_id] = callback_task
+    AsyncBasedSession.get_current_session().coros[callback_task.coro_id] = callback_task
 
     return callback_task.coro_id

+ 2 - 2
pywebio/ioloop.py

@@ -1,12 +1,12 @@
 import tornado.websocket
 from tornado.web import StaticFileHandler
-from .framework import Global
+from .session.asyncbased import AsyncBasedSession
 from .platform import STATIC_PATH
 from .platform.tornado import webio_handler
 
 
 def run_async(coro_obj):
-    Global.active_ws.inactive_coro_instances.append(coro_obj)
+    AsyncBasedSession.get_current_session().run_async(coro_obj)
 
 
 def start_ioloop(coro_func, port=8080, debug=True, tornado_app_args=None):

+ 0 - 1
pywebio/output.py

@@ -34,7 +34,6 @@ r"""输出内容到用户浏览器
 from base64 import b64encode
 from collections.abc import Mapping
 
-from .framework import Global
 from .io_ctrl import output_register_callback, send_msg
 
 

+ 20 - 9
pywebio/platform/flask.py

@@ -1,6 +1,19 @@
 """
 Flask backend
 
+.. attention::
+    PyWebIO 的会话状态保存在进程内,所以不支持多进程部署的Flask。
+        比如使用 ``uWSGI`` 部署Flask,并使用 ``--processes n`` 选项设置了多进程;
+        或者使用 ``nginx`` 等反向代理将流量负载到多个 Flask 副本上。
+
+    A note on run Flask with uWSGI:
+
+    If you start uWSGI without threads, the Python GIL will not be enabled,
+    so threads generated by your application will never run. `uWSGI doc <https://uwsgi-docs.readthedocs.io/en/latest/WSGIquickstart.html#a-note-on-python-threads>`_
+    在Flask backend中,PyWebIO使用单独一个线程来运行事件循环。如果程序中没有使用到asyncio中的协程函数,
+    可以在 start_flask_server 参数中设置 ``disable_asyncio=False`` 来关闭对asyncio协程函数的支持。
+    如果您需要使用asyncio协程函数,那么需要在在uWSGI中使用 ``--enable-thread`` 选项开启线程支持。
+
 """
 import asyncio
 import threading
@@ -11,11 +24,11 @@ from typing import Dict
 from flask import Flask, request, jsonify, send_from_directory
 
 from . import STATIC_PATH
-from ..framework import WebIOSession
+from ..session import AsyncBasedSession
 from ..utils import random_str, LRUDict
 
-# todo Flask 的线程模型是否会造成竞争条件?
-_webio_sessions: Dict[str, WebIOSession] = {}  # WebIOSessionID -> WebIOSession()
+# todo: use lock to avoid thread race condition
+_webio_sessions: Dict[str, AsyncBasedSession] = {}  # WebIOSessionID -> WebIOSession()
 _webio_expire = LRUDict()  # WebIOSessionID -> last active timestamp
 
 DEFAULT_SESSION_EXPIRE_SECONDS = 60 * 60 * 4  # 超过4个小时会话不活跃则视为会话过期
@@ -24,10 +37,8 @@ REMOVE_EXPIRED_SESSIONS_INTERVAL = 120  # 清理过期会话间隔(秒)
 _event_loop = None
 
 
-def _make_response(webio_session: WebIOSession):
-    res = webio_session.unhandled_server_msgs
-    webio_session.unhandled_server_msgs = []
-    return jsonify(res)
+def _make_response(webio_session: AsyncBasedSession):
+    return jsonify(webio_session.get_task_messages())
 
 
 def _remove_expired_sessions(session_expire_seconds):
@@ -65,7 +76,7 @@ def _webio_view(coro_func, session_expire_seconds):
     webio_session_id = None
     if 'webio_session_id' not in request.cookies:  # start new WebIOSession
         webio_session_id = random_str(24)
-        webio_session = WebIOSession(coro_func)
+        webio_session = AsyncBasedSession(coro_func)
         _webio_sessions[webio_session_id] = webio_session
         _webio_expire[webio_session_id] = time.time()
     elif request.cookies['webio_session_id'] not in _webio_sessions:  # WebIOSession deleted
@@ -75,7 +86,7 @@ def _webio_view(coro_func, session_expire_seconds):
         webio_session = _webio_sessions[webio_session_id]
 
     if request.method == 'POST':  # client push event
-        webio_session.send_client_msg(request.json)
+        webio_session.send_client_event(request.json)
 
     elif request.method == 'GET':  # client pull messages
         pass

+ 6 - 6
pywebio/platform/tornado.py

@@ -2,7 +2,7 @@ import json
 
 import tornado
 import tornado.websocket
-from ..framework import WebIOSession
+from ..session import AsyncBasedSession
 
 
 def webio_handler(coro_func, debug=True):
@@ -15,21 +15,21 @@ def webio_handler(coro_func, debug=True):
             # Non-None enables compression with default options.
             return {}
 
-        def on_coro_msg(self, controller):
-            while controller.unhandled_server_msgs:
-                msg = controller.unhandled_server_msgs.pop()
+        def send_msg_to_client(self, controller: AsyncBasedSession):
+            for msg in controller.get_task_messages():
                 self.write_message(json.dumps(msg))
 
         def open(self):
             print("WebSocket opened")
             self.set_nodelay(True)
 
-            self.controller = WebIOSession(coro_func, on_coro_msg=self.on_coro_msg, on_session_close=self.close)
+            self.controller = AsyncBasedSession(coro_func, on_task_message=self.send_msg_to_client,
+                                                on_session_close=self.close)
 
         def on_message(self, message):
             # print('on_message', message)
             data = json.loads(message)
-            self.controller.send_client_msg(data)
+            self.controller.send_client_event(data)
 
         def on_close(self):
             self.controller.close(no_session_close_callback=True)

+ 1 - 0
pywebio/session/__init__.py

@@ -0,0 +1 @@
+from .asyncbased import AsyncBasedSession

+ 76 - 48
pywebio/framework.py → pywebio/session/asyncbased.py

@@ -3,7 +3,8 @@ import sys
 import traceback
 from contextlib import contextmanager
 import asyncio
-from .utils import random_str
+from ..utils import random_str
+from .base import AbstractSession
 
 logger = logging.getLogger(__name__)
 
@@ -19,7 +20,12 @@ class WebIOFuture:
     __await__ = __iter__  # make compatible with 'await' expression
 
 
-class WebIOSession:
+class _context:
+    current_session = None  # type:"AsyncBasedSession"
+    current_task_id = None
+
+
+class AsyncBasedSession(AbstractSession):
     """
     一个PyWebIO任务会话, 由不同的后端Backend创建并维护
 
@@ -31,79 +37,91 @@ class WebIOSession:
         后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
     """
 
-    def __init__(self, coro_func, on_coro_msg=None, on_session_close=None):
+    @staticmethod
+    def get_current_session() -> "AsyncBasedSession":
+        if _context.current_session is None:
+            raise RuntimeError("No current found in context!")
+        return _context.current_session
+
+    @staticmethod
+    def get_current_task_id():
+        if _context.current_task_id is None:
+            raise RuntimeError("No current task found in context!")
+        return _context.current_task_id
+
+    def __init__(self, coroutine_func, on_task_message=None, on_session_close=None):
         """
         :param coro_func: 协程函数
         :param on_coro_msg: 由协程内发给session的消息的处理函数
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
         """
-        self._on_coro_msg = on_coro_msg or (lambda _: None)
+        self._on_task_message = on_task_message or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
-        self.unhandled_server_msgs = []
+        self.unhandled_task_msgs = []
 
         self.coros = {}  # coro_id -> coro
 
         self._closed = False
         self.inactive_coro_instances = []  # 待激活的协程实例列表
 
-        self.main_task = Task(coro_func(), ws=self)
+        self.main_task = Task(coroutine_func(), session=self, on_coro_stop=self._on_main_task_finish)
         self.coros[self.main_task.coro_id] = self.main_task
 
         self._step_task(self.main_task)
 
     def _step_task(self, task, result=None):
         task.step(result)
-        if task.task_finished:
+        if task.task_finished and task.coro_id in self.coros:
+            # 若task 为main task,则 task.step(result) 结束后,可能task已经结束,self.coros已被清理
             logger.debug('del self.coros[%s]', task.coro_id)
             del self.coros[task.coro_id]
 
-        while self.inactive_coro_instances:
+        while self.inactive_coro_instances and not self.main_task.task_finished:
             coro = self.inactive_coro_instances.pop()
-            sub_task = Task(coro, ws=self)
+            sub_task = Task(coro, session=self)
             self.coros[sub_task.coro_id] = sub_task
             sub_task.step()
             if sub_task.task_finished:
                 logger.debug('del self.coros[%s]', sub_task.coro_id)
                 del self.coros[sub_task.coro_id]
 
-        if self.main_task.task_finished:
-            self.send_coro_msg(dict(command='close_session'))
-            self.close()
+    def _on_main_task_finish(self):
+        self.send_task_message(dict(command='close_session'))
+        self.close()
 
-    def send_coro_msg(self, message):
+    def send_task_message(self, message):
         """向会话发送来自协程内的消息
 
         :param dict message: 消息
         """
-        self.unhandled_server_msgs.append(message)
-        self._on_coro_msg(self)
+        self.unhandled_task_msgs.append(message)
+        self._on_task_message(self)
 
-    def send_client_msg(self, message):
+    async def next_client_event(self):
+        res = await WebIOFuture()
+        return res
+
+    def send_client_event(self, event):
         """向会话发送来自用户浏览器的事件️
 
-        :param dict message: 事件️消息
+        :param dict event: 事件️消息
         """
-        # data = json.loads(message)
-        coro_id = message['coro_id']
+        coro_id = event['coro_id']
         coro = self.coros.get(coro_id)
         if not coro:
             logger.error('coro not found, coro_id:%s', coro_id)
             return
 
-        self._step_task(coro, message)
+        self._step_task(coro, event)
 
-    def on_coro_error(self):
-        from .output import put_markdown  # todo
-        logger.exception('Error in coroutine executing')
-        type, value, tb = sys.exc_info()
-        tb_len = len(list(traceback.walk_tb(tb)))
-        lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
-        traceback_msg = ''.join(lines)
-        put_markdown("发生错误:\n```\n%s\n```" % traceback_msg)
+    def get_task_messages(self):
+        msgs = self.unhandled_task_msgs
+        self.unhandled_task_msgs = []
+        return msgs
 
     def _cleanup(self):
         for t in self.coros.values():
-            t.cancel()
+            t.close()
         self.coros = {}  # delete session tasks
 
         while self.inactive_coro_instances:
@@ -125,21 +143,35 @@ class WebIOSession:
     def closed(self):
         return self._closed
 
+    def on_task_exception(self):
+        from ..output import put_markdown  # todo
+        logger.exception('Error in coroutine executing')
+        type, value, tb = sys.exc_info()
+        tb_len = len(list(traceback.walk_tb(tb)))
+        lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
+        traceback_msg = ''.join(lines)
+        put_markdown("发生错误:\n```\n%s\n```" % traceback_msg)
+
+    def run_async(self, coro_obj):
+        self.inactive_coro_instances.append(coro_obj)
+
 
 class Task:
     @contextmanager
-    def ws_context(self):
+    def session_context(self):
         """
-        >>> with ws_context():
+        >>> with session_context():
         ...     res = self.coros[-1].send(data)
         """
-        Global.active_ws = self.ws
-        Global.active_coro_id = self.coro_id
+
+        # todo issue: with 语句可能发生嵌套,导致内层with退出时,将属性置空
+        _context.current_session = self.session
+        _context.current_task_id = self.coro_id
         try:
             yield
         finally:
-            Global.active_ws = None
-            Global.active_coro_id = None
+            _context.current_session = None
+            _context.current_task_id = None
 
     @staticmethod
     def gen_coro_id(coro=None):
@@ -149,12 +181,13 @@ class Task:
 
         return '%s-%s' % (name, random_str(10))
 
-    def __init__(self, coro, ws):
-        self.ws = ws
+    def __init__(self, coro, session: AsyncBasedSession, on_coro_stop=None):
+        self.session = session
         self.coro = coro
         self.coro_id = None
         self.result = None
         self.task_finished = False  # 任务完毕/取消
+        self.on_coro_stop = on_coro_stop or (lambda: None)
 
         self.coro_id = self.gen_coro_id(self.coro)
 
@@ -164,7 +197,7 @@ class Task:
 
     def step(self, result=None):
         coro_yield = None
-        with self.ws_context():
+        with self.session_context():
             try:
                 coro_yield = self.coro.send(result)
             except StopIteration as e:
@@ -172,8 +205,9 @@ class Task:
                     self.result = e.args[0]
                 self.task_finished = True
                 logger.debug('Task[%s] finished', self.coro_id)
+                self.on_coro_stop()
             except Exception as e:
-                self.ws.on_coro_error()
+                self.session.on_task_exception()
 
         future = None
         if isinstance(coro_yield, WebIOFuture):
@@ -181,7 +215,7 @@ class Task:
                 future = asyncio.run_coroutine_threadsafe(coro_yield.coro, asyncio.get_event_loop())
         elif coro_yield is not None:
             future = coro_yield
-        if not self.ws.closed() and hasattr(future, 'add_done_callback'):
+        if not self.session.closed() and hasattr(future, 'add_done_callback'):
             future.add_done_callback(self._tornado_future_callback)
             self.pending_futures[id(future)] = future
 
@@ -190,8 +224,8 @@ class Task:
             del self.pending_futures[id(future)]
             self.step(future.result())
 
-    def cancel(self):
-        logger.debug('Task[%s] canceled', self.coro_id)
+    def close(self):
+        logger.debug('Task[%s] closed', self.coro_id)
         self.coro.close()
         while self.pending_futures:
             _, f = self.pending_futures.popitem()
@@ -202,9 +236,3 @@ class Task:
     def __del__(self):
         if not self.task_finished:
             logger.warning('Task[%s] not finished when destroy', self.coro_id)
-
-
-class Global:
-    # todo issue: with 语句可能发生嵌套,导致内层with退出时,将属性置空
-    active_ws = None  # type:"WebIOController"
-    active_coro_id = None

+ 54 - 0
pywebio/session/base.py

@@ -0,0 +1,54 @@
+class AbstractSession:
+    """
+    由Task在当前Session上下文中调用:
+        get_current_session
+        get_current_task_id
+
+        send_task_message
+        next_client_event
+        on_task_exception
+
+    由Backend调用:
+        send_client_event
+        get_task_messages
+
+    Task和Backend都可调用:
+        close
+        closed
+
+    .. note::
+        后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
+    """
+
+    @staticmethod
+    def get_current_session() -> "AbstractSession":
+        raise NotImplementedError
+
+    @staticmethod
+    def get_current_task_id():
+        raise NotImplementedError
+
+    def __init__(self, target, on_task_message=None, on_session_close=None, **kwargs):
+        raise NotImplementedError
+
+    def send_task_message(self, message):
+        raise NotImplementedError
+
+    def next_client_event(self):
+        raise NotImplementedError
+
+    def send_client_event(self, event):
+        raise NotImplementedError
+
+    def get_task_messages(self):
+        raise NotImplementedError
+
+    def close(self):
+        raise NotImplementedError
+
+    def closed(self):
+        raise NotImplementedError
+
+    def on_task_exception(self):
+        raise NotImplementedError
+