1
0
Эх сурвалжийг харах

refactor session reconnect

wangweimin 3 жил өмнө
parent
commit
e675db6550

+ 0 - 0
pywebio/platform/adaptor/__init__.py


+ 199 - 0
pywebio/platform/adaptor/ws.py

@@ -0,0 +1,199 @@
+import asyncio
+import json
+import logging
+import time
+import typing
+from typing import Dict
+import abc
+from ..utils import deserialize_binary_event
+from ...session import CoroutineBasedSession, ThreadBasedSession, Session
+from ...utils import iscoroutinefunction, isgeneratorfunction, \
+    random_str, LRUDict
+
+logger = logging.getLogger(__name__)
+
+
+class _state:
+    """Module-wide registries shared by all `WebSocketHandler` instances, keyed by session id."""
+
+    # only used when reconnect is enabled
+    # used to clean up sessions whose connection was lost and not re-established in time
+    detached_sessions = LRUDict()  # session_id -> detached timestamp. In increasing order of the time
+
+    # unclosed and unexpired sessions
+    # used to clean up sessions
+    # used to retrieve a session by id when a new connection arrives or when the session closes
+    unclosed_sessions: Dict[str, Session] = {}  # session_id -> session
+
+    # the messages that couldn't be delivered to the browser because the session closed while the connection was lost
+    undelivered_messages: Dict[str, list] = {}  # session_id -> unhandled message list
+
+    # used to get the active connection in the session's callbacks
+    active_connections: Dict[str, 'WebSocketConnection'] = {}  # session_id -> WebSocketConnection
+
+    # seconds a detached session is kept before `clean_expired_sessions()` closes it
+    expire_second = 10
+
+
+def set_expire_second(sec):
+    """Raise the session expiry window to ``sec`` seconds; a smaller value than the current one is ignored."""
+    if sec > _state.expire_second:
+        _state.expire_second = sec
+
+
+def clean_expired_sessions():
+    """Close and forget every detached session whose expiry window has elapsed.
+
+    Entries of ``_state.detached_sessions`` are taken from the head one by one; the scan stops at the first
+    entry that has not expired yet.
+    """
+    while _state.detached_sessions:
+        session_id, detached_ts = _state.detached_sessions.popitem(last=False)  # pop the session that expires earliest
+
+        if time.time() < detached_ts + _state.expire_second:
+            # this session is not expired
+            _state.detached_sessions[session_id] = detached_ts  # restore
+            _state.detached_sessions.move_to_end(session_id, last=False)  # move to head
+            break
+
+        # clean this session: drop every registry entry for it, then close the session itself
+        logger.debug("session %s expired" % session_id)
+        _state.active_connections.pop(session_id, None)
+        _state.undelivered_messages.pop(session_id, None)
+        session = _state.unclosed_sessions.pop(session_id, None)
+        if session:
+            session.close(nonblock=True)
+
+
+async def session_clean_task():
+    """Run `clean_expired_sessions()` forever, sleeping half of the expiry window between runs."""
+    logger.debug("Start session cleaning task")
+    while True:
+        try:
+            clean_expired_sessions()
+        except Exception:
+            # keep the task alive: a failure in one round must not stop future clean-ups
+            logger.exception("Error when clean expired sessions")
+
+        await asyncio.sleep(_state.expire_second // 2)
+
+
+class WebSocketConnection(abc.ABC):
+    """Interface that a web framework backend implements so `WebSocketHandler` can drive its WebSocket connection."""
+
+    @abc.abstractmethod
+    def get_query_argument(self, name) -> typing.Optional[str]:
+        """Return the value of the request's query argument ``name``.
+
+        `WebSocketHandler` treats a ``None`` result for ``'session'`` as a request for a new session.
+        """
+        pass
+
+    @abc.abstractmethod
+    def make_session_info(self) -> dict:
+        """Return the dict passed as ``session_info`` when a new session is created for this connection."""
+        pass
+
+    @abc.abstractmethod
+    def write_message(self, message: dict):
+        """Send one message (a dict) to the client through this connection."""
+        pass
+
+    @abc.abstractmethod
+    def closed(self) -> bool:
+        """Return whether this connection is closed."""
+        return False
+
+    @abc.abstractmethod
+    def close(self):
+        """Close this connection."""
+        pass
+
+
+class WebSocketHandler:
+    """
+    Framework-independent logic for one WebSocket connection.
+
+    Each instance is held by one connection. A session can be shared by multiple connections during its
+    lifetime (reconnect), but it is attached to only one connection at a time.
+    """
+
+    session_id: str = None
+    # the session that the current connection attaches to; stays None when the requested session has expired
+    session: Session = None
+    connection: WebSocketConnection
+    reconnectable: bool
+
+    def __init__(self, connection: WebSocketConnection, application, reconnectable: bool):
+        """
+        Attach ``connection`` to a new session or to the existing session named by its ``session`` query argument.
+
+        :param connection: wrapper of the backend's WebSocket connection
+        :param application: task function used to create the session when a new one is requested
+        :param bool reconnectable: whether the client is told its session id so it can re-attach later
+        """
+        logger.debug("WebSocket opened")
+        self.connection = connection
+        self.reconnectable = reconnectable
+        self.session_id = connection.get_query_argument('session')
+
+        if self.session_id in ('NEW', None):  # initial request, create a new session
+            self._init_session(application)
+            if reconnectable:
+                # set session id to client, so the client can send it back to server to recover a session when it
+                # resumes from a connection lost
+                connection.write_message(dict(command='set_session_id', spec=self.session_id))
+        elif self.session_id not in _state.unclosed_sessions:  # session is expired
+            # no session is attached in this branch: `self.session` stays None
+            bye_msg = dict(command='close_session')
+            for m in _state.undelivered_messages.get(self.session_id, [bye_msg]):
+                try:
+                    connection.write_message(m)
+                except Exception:
+                    logger.exception("Error in sending message via websocket")
+        else:
+            # re-attach to the live session and make this connection the active one
+            self.session = _state.unclosed_sessions[self.session_id]
+            _state.detached_sessions.pop(self.session_id, None)
+            _state.active_connections[self.session_id] = connection
+            # send the latest messages to client
+            self._send_msg_to_client(self.session)
+
+        logger.debug('session id: %s' % self.session_id)
+
+    def _init_session(self, application):
+        """Create a session for ``application`` under a fresh random id and register it with this connection."""
+        session_info = self.connection.make_session_info()
+        self.session_id = random_str(24)
+        # todo: only set item when reconnection enabled
+        _state.active_connections[self.session_id] = self.connection
+
+        if iscoroutinefunction(application) or isgeneratorfunction(application):
+            self.session = CoroutineBasedSession(
+                application, session_info=session_info,
+                on_task_command=self._send_msg_to_client,
+                on_session_close=self._close_from_session)
+        else:
+            self.session = ThreadBasedSession(
+                application, session_info=session_info,
+                on_task_command=self._send_msg_to_client,
+                on_session_close=self._close_from_session,
+                loop=asyncio.get_event_loop())
+        _state.unclosed_sessions[self.session_id] = self.session
+
+    def _send_msg_to_client(self, session):
+        """Forward the pending task commands of ``session`` to the connection currently active for it.
+
+        Nothing is fetched from the session when there is no active, open connection.
+        """
+        # self.connection may not be active,
+        # here we need the active connection for this session
+        conn = _state.active_connections.get(self.session_id)
+
+        if not conn or conn.closed():
+            return
+
+        for msg in session.get_task_commands():
+            try:
+                conn.write_message(msg)
+            except TypeError as e:
+                logger.exception('Data serialization error: %s\n'
+                                 'This may be because you pass the wrong type of parameter to the function'
+                                 ' of PyWebIO.\nData content: %s', e, msg)
+            except Exception:
+                logger.exception("Error in sending message via websocket")
+
+    def _close_from_session(self):
+        """Session-close callback: flush or stash the remaining messages, unregister the session, close its connection."""
+        session = _state.unclosed_sessions.get(self.session_id)
+        if session is None:
+            # The session has already been removed from the registry (e.g. it was cleaned up as expired,
+            # or this callback ran before). Nothing is left to flush or unregister; returning here also
+            # avoids re-creating an `undelivered_messages` entry that nothing would clean up afterwards.
+            return
+
+        if self.session_id in _state.active_connections:
+            # send the undelivered messages to client
+            self._send_msg_to_client(session=session)
+        else:
+            # no connection right now: keep the messages for a client that reconnects later
+            _state.undelivered_messages[self.session_id] = session.get_task_commands()
+
+        conn = _state.active_connections.pop(self.session_id, None)
+        _state.unclosed_sessions.pop(self.session_id, None)
+        if conn and not conn.closed():
+            conn.close()
+
+    def send_client_data(self, data):
+        """Decode ``data`` received from the client (bytes or JSON text) and pass the event to the session."""
+        if self.session is None:
+            # the client asked for an expired session, so there is no session to receive the event
+            logger.debug("Drop client data: no session attached to this connection")
+            return
+        if isinstance(data, bytes):
+            event = deserialize_binary_event(data)
+        else:
+            event = json.loads(data)
+        if event is None:
+            return
+        self.session.send_client_event(event)
+
+    def notify_connection_lost(self):
+        """Handle the loss of this handler's connection.
+
+        Without reconnect support the session is closed; with it, the session is marked as detached so that
+        it can be re-attached until it expires.
+        """
+        registered = _state.active_connections.get(self.session_id)
+        # A newer connection may already have re-attached to the session before this loss was noticed;
+        # in that case its registration must stay and the session must not be marked as detached.
+        superseded = registered is not None and registered is not self.connection
+        if not superseded:
+            _state.active_connections.pop(self.session_id, None)
+
+        if not self.reconnectable:
+            # when the connection lost is caused by `on_session_close()`, it's OK to close the session here though.
+            # because the `session.close()` is reentrant
+            if self.session is not None:  # None when the client asked for an expired session
+                self.session.close(nonblock=True)
+        elif not superseded and self.session_id in _state.unclosed_sessions:
+            _state.detached_sessions[self.session_id] = time.time()
+        logger.debug("WebSocket closed")

+ 51 - 134
pywebio/platform/tornado.py

@@ -4,27 +4,25 @@ import json
 import logging
 import os
 import threading
-import time
+import typing
 import webbrowser
 from functools import partial
-from typing import Dict
 from urllib.parse import urlparse
 
 import tornado
 import tornado.httpserver
 import tornado.ioloop
-from tornado.web import StaticFileHandler
-from tornado.websocket import WebSocketHandler
+import tornado.web
+import tornado.websocket
 
 from . import page
-from .remote_access import start_remote_access_service
+from .adaptor import ws as ws_adaptor
 from .page import make_applications, render_page
-from .utils import cdn_validation, deserialize_binary_event, print_listen_address
-from ..session import CoroutineBasedSession, ThreadBasedSession, ScriptModeSession, \
-    register_session_implement_for_target, Session
+from .remote_access import start_remote_access_service
+from .utils import cdn_validation, print_listen_address
+from ..session import ScriptModeSession, register_session_implement_for_target, Session
 from ..session.base import get_session_info_from_headers
-from ..utils import get_free_port, wait_host_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction, \
-    check_webio_js, parse_file_size, random_str, LRUDict
+from ..utils import get_free_port, wait_host_port, STATIC_PATH, check_webio_js, parse_file_size
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +43,7 @@ def ioloop() -> tornado.ioloop.IOLoop:
     return _ioloop
 
 
-def _check_origin(origin, allowed_origins, handler: WebSocketHandler):
+def _check_origin(origin, allowed_origins, handler: tornado.websocket.WebSocketHandler):
     if _is_same_site(origin, handler):
         return True
 
@@ -55,7 +53,7 @@ def _check_origin(origin, allowed_origins, handler: WebSocketHandler):
     )
 
 
-def _is_same_site(origin, handler: WebSocketHandler):
+def _is_same_site(origin, handler: tornado.websocket.WebSocketHandler):
     parsed_origin = urlparse(origin)
     origin = parsed_origin.netloc
     origin = origin.lower()
@@ -66,6 +64,32 @@ def _is_same_site(origin, handler: WebSocketHandler):
     return origin == host
 
 
+class WebSocketConnection(ws_adaptor.WebSocketConnection):
+    """`ws_adaptor.WebSocketConnection` implementation backed by a Tornado `WebSocketHandler`."""
+
+    def __init__(self, context: tornado.websocket.WebSocketHandler):
+        """:param context: the Tornado handler that owns the underlying WebSocket connection"""
+        self.context = context
+
+    def get_query_argument(self, name) -> typing.Optional[str]:
+        """Return the query argument ``name`` of the request, or None when it is absent."""
+        return self.context.get_query_argument(name, None)
+
+    def make_session_info(self) -> dict:
+        """Build the session info dict from the request headers, plus client IP, request object and backend tags."""
+        session_info = get_session_info_from_headers(self.context.request.headers)
+        session_info['user_ip'] = self.context.request.remote_ip
+        session_info['request'] = self.context.request
+        session_info['backend'] = 'tornado'
+        session_info['protocol'] = 'websocket'
+        return session_info
+
+    def write_message(self, message: dict):
+        """Serialize ``message`` to JSON and send it through the Tornado handler."""
+        self.context.write_message(json.dumps(message))
+
+    def closed(self) -> bool:
+        """Return True when the Tornado handler no longer has a ``ws_connection``."""
+        return not bool(self.context.ws_connection)
+
+    def close(self):
+        """Close the connection through the Tornado handler."""
+        self.context.close()
+
+
 def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origin_func=_is_same_site):  # noqa: C901
     """
     :param dict applications: dict of `name -> task function`
@@ -78,16 +102,10 @@ def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origi
     if applications is None:
         applications = dict(index=lambda: None)  # mock PyWebIO app
 
-    class WSHandler(WebSocketHandler):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self._close_from_session = False
-            self.session_id = None
-            self.session = None  # type: Session
-            if reconnect_timeout and not type(self)._started_clean_task:
-                type(self)._started_clean_task = True
-                tornado.ioloop.IOLoop.current().call_later(reconnect_timeout // 2, type(self).clean_expired_sessions)
-                logger.debug("Started session clean task")
+    ws_adaptor.set_expire_second(reconnect_timeout)
+    tornado.ioloop.IOLoop.current().spawn_callback(ws_adaptor.session_clean_task)
+
+    class Handler(tornado.websocket.WebSocketHandler):
 
         def get_app(self):
             app_name = self.get_query_argument('app', 'index')
@@ -100,7 +118,7 @@ def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origi
             return cdn
 
         async def get(self, *args, **kwargs) -> None:
-            # It's a simple http GET request
+            """http GET request"""
             if self.request.headers.get("Upgrade", "").lower() != "websocket":
                 # Backward compatible
                 # Frontend detect whether the backend is http server
@@ -108,7 +126,6 @@ def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origi
                     return self.write('')
 
                 app = self.get_app()
-
                 html = render_page(app, protocol='ws', cdn=self.get_cdn())
                 return self.write(html)
             else:
@@ -121,121 +138,21 @@ def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origi
             # Non-None enables compression with default options.
             return {}
 
-        @classmethod
-        def clean_expired_sessions(cls):
-            tornado.ioloop.IOLoop.current().call_later(reconnect_timeout // 2, cls.clean_expired_sessions)
-
-            while cls._session_expire:
-                session_id, expire_ts = cls._session_expire.popitem(last=False)  # 弹出最早过期的session
-
-                if time.time() < expire_ts:
-                    # this session is not expired
-                    cls._session_expire[session_id] = expire_ts  # restore this item
-                    cls._session_expire.move_to_end(session_id, last=False)  # move to front
-                    break
-
-                # clean this session
-                logger.debug("session %s expired" % session_id)
-                cls._connections.pop(session_id, None)
-                session = cls._webio_sessions.pop(session_id, None)
-                if session:
-                    session.close(nonblock=True)
-
-        @classmethod
-        def send_msg_to_client(cls, _, session_id=None):
-            conn = cls._connections.get(session_id)
-            session = cls._webio_sessions[session_id]
-
-            if not conn or not conn.ws_connection:
-                return
-
-            for msg in session.get_task_commands():
-                try:
-                    conn.write_message(json.dumps(msg))
-                except TypeError as e:
-                    logger.exception('Data serialization error: %s\n'
-                                     'This may be because you pass the wrong type of parameter to the function'
-                                     ' of PyWebIO.\nData content: %s', e, msg)
-
-        @classmethod
-        def close_from_session(cls, session_id=None):
-            cls.send_msg_to_client(None, session_id=session_id)
-
-            conn = cls._connections.pop(session_id, None)
-            cls._webio_sessions.pop(session_id, None)
-            if conn and conn.ws_connection:
-                conn._close_from_session = True
-                conn.close()
-
-        _started_clean_task = False
-        _session_expire = LRUDict()  # session_id -> expire timestamp. In increasing order of expire time
-        _webio_sessions = {}  # type: Dict[str, Session]  # session_id -> session
-        _connections = {}  # type: Dict[str, WSHandler]  # session_id -> WSHandler
+        _handler: ws_adaptor.WebSocketHandler
 
         def open(self):
-            logger.debug("WebSocket opened")
-            cls = type(self)
-
-            self.session_id = self.get_query_argument('session', None)
-            if self.session_id in ('NEW', None):  # 初始请求,创建新 Session
-                session_info = get_session_info_from_headers(self.request.headers)
-                session_info['user_ip'] = self.request.remote_ip
-                session_info['request'] = self.request
-                session_info['backend'] = 'tornado'
-                session_info['protocol'] = 'websocket'
-
-                application = self.get_app()
-                self.session_id = random_str(24)
-                cls._connections[self.session_id] = self
-
-                if iscoroutinefunction(application) or isgeneratorfunction(application):
-                    self.session = CoroutineBasedSession(
-                        application, session_info=session_info,
-                        on_task_command=partial(self.send_msg_to_client, session_id=self.session_id),
-                        on_session_close=partial(self.close_from_session, session_id=self.session_id))
-                else:
-                    self.session = ThreadBasedSession(
-                        application, session_info=session_info,
-                        on_task_command=partial(self.send_msg_to_client, session_id=self.session_id),
-                        on_session_close=partial(self.close_from_session, session_id=self.session_id),
-                        loop=asyncio.get_event_loop())
-                cls._webio_sessions[self.session_id] = self.session
-
-                if reconnect_timeout:
-                    self.write_message(json.dumps(dict(command='set_session_id', spec=self.session_id)))
-
-            elif self.session_id not in cls._webio_sessions:  # WebIOSession deleted
-                self.write_message(json.dumps(dict(command='close_session')))
-            else:
-                self.session = cls._webio_sessions[self.session_id]
-                cls._session_expire.pop(self.session_id, None)
-                cls._connections[self.session_id] = self
-                cls.send_msg_to_client(self.session, self.session_id)
-
-            logger.debug('session id: %s' % self.session_id)
+            conn = WebSocketConnection(self)
+            self._handler = ws_adaptor.WebSocketHandler(
+                connection=conn, application=self.get_app(), reconnectable=bool(reconnect_timeout)
+            )
 
         def on_message(self, message):
-            if isinstance(message, bytes):
-                event = deserialize_binary_event(message)
-            else:
-                event = json.loads(message)
-            if event is None:
-                return
-            self.session.send_client_event(event)
+            self._handler.send_client_data(message)
 
         def on_close(self):
-            cls = type(self)
-            cls._connections.pop(self.session_id, None)
-            if not reconnect_timeout and not self._close_from_session:
-                self.session.close(nonblock=True)
-            elif reconnect_timeout:
-                if self._close_from_session:
-                    cls._webio_sessions.pop(self.session_id, None)
-                elif self.session:
-                    cls._session_expire[self.session_id] = time.time() + reconnect_timeout
-            logger.debug("WebSocket closed")
+            self._handler.notify_connection_lost()
 
-    return WSHandler
+    return Handler
 
 
 def webio_handler(applications, cdn=True, reconnect_timeout=0, allowed_origins=None, check_origin=None):
@@ -278,9 +195,9 @@ def _setup_server(webio_handler, port=0, host='', static_dir=None, max_buffer_si
     handlers = [(r"/", webio_handler)]
 
     if static_dir is not None:
-        handlers.append((r"/static/(.*)", StaticFileHandler, {"path": static_dir}))
+        handlers.append((r"/static/(.*)", tornado.web.StaticFileHandler, {"path": static_dir}))
 
-    handlers.append((r"/(.*)", StaticFileHandler, {"path": STATIC_PATH, 'default_filename': 'index.html'}))
+    handlers.append((r"/(.*)", tornado.web.StaticFileHandler, {"path": STATIC_PATH, 'default_filename': 'index.html'}))
 
     app = tornado.web.Application(handlers=handlers, **tornado_app_settings)
     # Credit: https://stackoverflow.com/questions/19074972/content-length-too-long-when-uploading-file-using-tornado

+ 3 - 0
pywebio/session/threadbased.py

@@ -101,6 +101,9 @@ class ThreadBasedSession(Session):
                 except SessionException:  # ignore SessionException error
                     pass
                 finally:
+                    # we need first trigger close event and then perform close operation,
+                    # because close operation will clean up all resources in this session,
+                    # which may need to be accessed in close event
                     self._trigger_close_event()
                     self.close()
 

+ 2 - 1
setup.py

@@ -2,6 +2,7 @@ import os
 from functools import reduce
 
 from setuptools import setup
+from setuptools import find_namespace_packages
 
 here = os.path.abspath(os.path.dirname(__file__))
 
@@ -33,7 +34,7 @@ setup(
     url=about['__url__'],
     license=about['__license__'],
     python_requires=">=3.5.2",
-    packages=['pywebio', 'pywebio.session', 'pywebio.platform'],
+    packages=['pywebio', 'pywebio.session', 'pywebio.platform', 'pywebio.platform.adaptor'],
     scripts=['tools/pywebio-path-deploy'],
     package_data={
         # data files need to be listed both here (which determines what gets