浏览代码

add session reconnection to aiohttp

wangweimin 3 年之前
父节点
当前提交
953f836d88
共有 2 个文件被更改,包括 50 次插入46 次删除
  1. 3 2
      pywebio/platform/adaptor/ws.py
  2. 47 44
      pywebio/platform/aiohttp.py

+ 3 - 2
pywebio/platform/adaptor/ws.py

@@ -100,11 +100,12 @@ class WebSocketHandler:
     connection: WebSocketConnection
     reconnectable: bool
 
-    def __init__(self, connection: WebSocketConnection, application, reconnectable: bool):
+    def __init__(self, connection: WebSocketConnection, application, reconnectable: bool, ioloop=None):
         logger.debug("WebSocket opened")
         self.connection = connection
         self.reconnectable = reconnectable
         self.session_id = connection.get_query_argument('session')
+        self.ioloop = ioloop or asyncio.get_event_loop()
 
         if self.session_id in ('NEW', None):  # 初始请求,创建新 Session
             self._init_session(application)
@@ -144,7 +145,7 @@ class WebSocketHandler:
                 application, session_info=session_info,
                 on_task_command=self._send_msg_to_client,
                 on_session_close=self._close_from_session,
-                loop=asyncio.get_event_loop())
+                loop=self.ioloop)
         _state.unclosed_sessions[self.session_id] = self.session
 
     def _send_msg_to_client(self, session):

+ 47 - 44
pywebio/platform/aiohttp.py

@@ -3,18 +3,20 @@ import fnmatch
 import json
 import logging
 import os
+import typing
 from functools import partial
 from urllib.parse import urlparse
 
 from aiohttp import web
 
+from .adaptor import ws as ws_adaptor
 from .page import make_applications, render_page
 from .remote_access import start_remote_access_service
 from .tornado import open_webbrowser_on_server_started
-from .utils import cdn_validation, deserialize_binary_event, print_listen_address
-from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
+from .utils import cdn_validation, print_listen_address
+from ..session import register_session_implement_for_target, Session
 from ..session.base import get_session_info_from_headers
-from ..utils import get_free_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction
+from ..utils import get_free_port, STATIC_PATH
 
 logger = logging.getLogger(__name__)
 
@@ -39,7 +41,36 @@ def _is_same_site(origin, host):
     return origin == host
 
 
-def _webio_handler(applications, cdn, websocket_settings, check_origin_func=_is_same_site):
+class WebSocketConnection(ws_adaptor.WebSocketConnection):
+
+    def __init__(self, ws: web.WebSocketResponse, http: web.Request, ioloop):
+        self.ws = ws
+        self.http = http
+        self.ioloop = ioloop
+
+    def get_query_argument(self, name) -> typing.Optional[str]:
+        return self.http.query.getone(name, None)
+
+    def make_session_info(self) -> dict:
+        session_info = get_session_info_from_headers(self.http.headers)
+        session_info['user_ip'] = self.http.remote
+        session_info['request'] = self.http
+        session_info['backend'] = 'aiohttp'
+        session_info['protocol'] = 'websocket'
+        return session_info
+
+    def write_message(self, message: dict):
+        msg_str = json.dumps(message)
+        self.ioloop.create_task(self.ws.send_str(msg_str))
+
+    def closed(self) -> bool:
+        return self.ws.closed
+
+    def close(self):
+        self.ioloop.create_task(self.ws.close())
+
+
+def _webio_handler(applications, cdn, websocket_settings, reconnect_timeout=0, check_origin_func=_is_same_site):
     """
     :param dict applications: dict of `name -> task function`
     :param bool/str cdn: Whether to load front-end static resources from CDN
@@ -68,61 +99,31 @@ def _webio_handler(applications, cdn, websocket_settings, check_origin_func=_is_
         ws = web.WebSocketResponse(**websocket_settings)
         await ws.prepare(request)
 
-        close_from_session_tag = False  # 是否由session主动关闭连接
-
-        def send_msg_to_client(session: Session):
-            for msg in session.get_task_commands():
-                msg_str = json.dumps(msg)
-                ioloop.create_task(ws.send_str(msg_str))
-
-        def close_from_session():
-            nonlocal close_from_session_tag
-            close_from_session_tag = True
-            ioloop.create_task(ws.close())
-            logger.debug("WebSocket closed from session")
-
-        session_info = get_session_info_from_headers(request.headers)
-        session_info['user_ip'] = request.remote
-        session_info['request'] = request
-        session_info['backend'] = 'aiohttp'
-        session_info['protocol'] = 'websocket'
-
         app_name = request.query.getone('app', 'index')
         application = applications.get(app_name) or applications['index']
 
-        if iscoroutinefunction(application) or isgeneratorfunction(application):
-            session = CoroutineBasedSession(application, session_info=session_info,
-                                            on_task_command=send_msg_to_client,
-                                            on_session_close=close_from_session)
-        else:
-            session = ThreadBasedSession(application, session_info=session_info,
-                                         on_task_command=send_msg_to_client,
-                                         on_session_close=close_from_session, loop=ioloop)
+        conn = WebSocketConnection(ws, request, ioloop)
+        handler = ws_adaptor.WebSocketHandler(
+            connection=conn, application=application, reconnectable=bool(reconnect_timeout), ioloop=ioloop
+        )
 
         # see: https://github.com/aio-libs/aiohttp/issues/1768
         try:
             async for msg in ws:
-                if msg.type == web.WSMsgType.text:
-                    data = msg.json()
-                elif msg.type == web.WSMsgType.binary:
-                    data = deserialize_binary_event(msg.data)
+                if msg.type in (web.WSMsgType.text, web.WSMsgType.binary):
+                    handler.send_client_data(msg.data)
                 elif msg.type == web.WSMsgType.close:
                     raise asyncio.CancelledError()
-
-                if data is not None:
-                    session.send_client_event(data)
         finally:
-            if not close_from_session_tag:
-                # close session because client disconnected to server
-                session.close(nonblock=True)
-                logger.debug("WebSocket closed from client")
+            handler.notify_connection_lost()
 
         return ws
 
     return wshandle
 
 
-def webio_handler(applications, cdn=True, allowed_origins=None, check_origin=None, websocket_settings=None):
+def webio_handler(applications, cdn=True, reconnect_timeout=0, allowed_origins=None, check_origin=None,
+                  websocket_settings=None):
     """Get the `Request Handler <https://docs.aiohttp.org/en/stable/web_quickstart.html#aiohttp-web-handler>`_ coroutine for running PyWebIO applications in aiohttp.
     The handler communicates with the browser by WebSocket protocol.
 
@@ -145,6 +146,7 @@ def webio_handler(applications, cdn=True, allowed_origins=None, check_origin=Non
 
     return _webio_handler(applications=applications, cdn=cdn,
                           check_origin_func=check_origin_func,
+                          reconnect_timeout=reconnect_timeout,
                           websocket_settings=websocket_settings)
 
 
@@ -168,6 +170,7 @@ def static_routes(prefix='/'):
 
 def start_server(applications, port=0, host='', debug=False,
                  cdn=True, static_dir=None, remote_access=False,
+                 reconnect_timeout=0,
                  allowed_origins=None, check_origin=None,
                  auto_open_webbrowser=False,
                  websocket_settings=None,
@@ -191,7 +194,7 @@ def start_server(applications, port=0, host='', debug=False,
 
     cdn = cdn_validation(cdn, 'warn')
 
-    handler = webio_handler(applications, cdn=cdn, allowed_origins=allowed_origins,
+    handler = webio_handler(applications, cdn=cdn, allowed_origins=allowed_origins, reconnect_timeout=reconnect_timeout,
                             check_origin=check_origin, websocket_settings=websocket_settings)
 
     app = web.Application(**aiohttp_settings)