123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- import abc
- import asyncio
- import json
- import logging
- import time
- import typing
- from typing import Dict, Optional
- from ...session import CoroutineBasedSession, Session, ThreadBasedSession
- from ...utils import LRUDict, iscoroutinefunction, isgeneratorfunction, random_str
- from ..utils import deserialize_binary_event
# Module-level logger shared by every function and class in this module.
logger = logging.getLogger(__name__)
class _reconnect_state:
    """Process-wide bookkeeping used when session reconnection is enabled.

    All state lives in class attributes and is read/written directly on the
    class object, so every handler in the process shares it.
    """

    # How long (seconds) a detached session may wait for a new connection
    # before it is considered expired. Only ever raised by set_expire_second.
    expire_second = 0

    # session_id -> timestamp at which the session lost its connection.
    # Entries are kept ordered from oldest to newest detach time; the
    # expiration sweep walks it from the head.
    detached_sessions = LRUDict()

    # session_id -> session that is neither closed-by-expiry nor forgotten yet.
    # Consulted when a new connection presents a session id, and emptied by the
    # expiration sweep.
    unclosed_sessions: Dict[str, Session] = {}

    # session_id -> messages that could not reach the browser because the
    # session closed while it had no connection.
    session_will_messages: Dict[str, list] = {}

    # session_id -> connection currently attached to the session; lets the
    # session's callbacks find the live connection.
    active_connections: Dict[str, 'WebSocketConnection'] = {}
def set_expire_second(sec):
    """Widen the session expiration window to ``sec`` seconds.

    The window never shrinks: if ``sec`` is not larger than the current
    ``_reconnect_state.expire_second``, the current value is kept.
    """
    current = _reconnect_state.expire_second
    _reconnect_state.expire_second = sec if sec > current else current
def clean_expired_sessions():
    """Forget and close every detached session whose reconnect window has elapsed.

    ``_reconnect_state.detached_sessions`` is ordered by increasing detach
    time, so the sweep stops at the first session that has not expired yet.
    For each expired session, its active-connection entry, pending "will"
    messages and unclosed-session entry are removed, and the session is closed
    without blocking.
    """
    while _reconnect_state.detached_sessions:
        # pop the session that was detached earliest
        session_id, detached_ts = _reconnect_state.detached_sessions.popitem(last=False)
        if time.time() < detached_ts + _reconnect_state.expire_second:
            # Not expired: put it back at the head to keep the ordering, and stop,
            # since every later entry was detached even more recently.
            _reconnect_state.detached_sessions[session_id] = detached_ts
            _reconnect_state.detached_sessions.move_to_end(session_id, last=False)
            break

        logger.debug("session %s expired", session_id)
        _reconnect_state.active_connections.pop(session_id, None)
        _reconnect_state.session_will_messages.pop(session_id, None)
        session = _reconnect_state.unclosed_sessions.pop(session_id, None)
        if session is not None:
            # The session is already removed from all bookkeeping; a failure while
            # closing it must not stop the sweep of the remaining expired sessions.
            try:
                session.close(nonblock=True)
            except Exception:
                logger.exception("Error when closing expired session %s", session_id)
# Set once the cleaning loop has started, so at most one loop runs per process.
_session_clean_task_started = False


async def session_clean_task():
    """Periodically remove expired detached sessions; runs forever once started.

    Returns immediately if a cleaning loop is already running or if no
    expiration time has been configured (``_reconnect_state.expire_second``
    is falsy). Otherwise it calls ``clean_expired_sessions`` every
    ``expire_second / 2`` seconds, logging (not propagating) any error.
    """
    global _session_clean_task_started
    if _session_clean_task_started or not _reconnect_state.expire_second:
        return

    _session_clean_task_started = True
    logger.debug("Start session cleaning task")
    while True:
        try:
            clean_expired_sessions()
        except Exception:
            logger.exception("Error when clean expired sessions")
        # True division: with floor division an expire time of 1 second gives
        # sleep(0), turning this loop into a busy loop that hogs the event loop.
        await asyncio.sleep(_reconnect_state.expire_second / 2)
class WebSocketConnection(abc.ABC):
    """Abstract websocket connection used by ``WebSocketHandler``.

    Subclasses supply the actual transport; the handler only talks to the
    client through the methods below.
    """

    @abc.abstractmethod
    def get_query_argument(self, name) -> typing.Optional[str]:
        """Return the query argument ``name`` of the connection request.

        The handler treats ``None`` as "no session id supplied" (presumably
        returned when the argument is absent; confirm in implementations).
        """

    @abc.abstractmethod
    def make_session_info(self) -> dict:
        """Build the ``session_info`` dict handed to a newly created session."""

    @abc.abstractmethod
    def write_message(self, message: dict):
        """Send ``message`` (a command dict) to the client."""

    @abc.abstractmethod
    def closed(self) -> bool:
        """Report whether the connection has been closed (base answer: not closed)."""
        return False

    @abc.abstractmethod
    def close(self):
        """Close the connection."""
class WebSocketHandler:
    """Bridge between one websocket connection and a session.

    One handler is created per connection. Without reconnection the session
    lives and dies with this connection. With reconnection enabled, a session
    may outlive its connection: a later connection presenting the same session
    id is attached to the existing session (one connection at a time), and the
    session's output is routed to whichever connection is currently registered
    in ``_reconnect_state.active_connections``.
    """

    session_id: Optional[str] = None
    session: Optional[Session] = None  # the session that the current connection is attached to
    connection: WebSocketConnection
    reconnectable: bool

    def __init__(self, connection: WebSocketConnection, application, reconnectable: bool, ioloop=None):
        """Attach ``connection`` to a new, resumed, or already expired session.

        :param connection: the websocket connection served by this handler.
        :param application: task function run by a newly created session.
        :param reconnectable: whether a session may survive a lost connection and be
            resumed by a later connection presenting the same session id.
        :param ioloop: event loop handed to thread-based sessions; defaults to
            ``asyncio.get_event_loop()``.
        """
        logger.debug("WebSocket opened")
        self.connection = connection
        self.reconnectable = reconnectable
        self.session_id = connection.get_query_argument('session')
        self.ioloop = ioloop or asyncio.get_event_loop()

        if self.session_id in ('NEW', None):  # initial request: create a new session
            self._init_session(application)
            if reconnectable:
                _reconnect_state.active_connections[self.session_id] = self.connection
                _reconnect_state.unclosed_sessions[self.session_id] = self.session
                # Give the session id to the client so it can send it back to recover
                # the session after a lost connection.
                connection.write_message(dict(command='set_session_id', spec=self.session_id))
        elif self.session_id not in _reconnect_state.unclosed_sessions:  # session has expired
            # Deliver the messages left over when the session closed without a connection,
            # or just tell the client that the session is gone. No session is attached.
            bye_msg = dict(command='close_session')
            for m in _reconnect_state.session_will_messages.get(self.session_id, [bye_msg]):
                try:
                    connection.write_message(m)
                except Exception:
                    logger.exception("Error in sending message via websocket")
        else:  # resuming from a lost connection
            self.session = _reconnect_state.unclosed_sessions[self.session_id]
            _reconnect_state.detached_sessions.pop(self.session_id, None)
            _reconnect_state.active_connections[self.session_id] = connection
            # send the latest messages to the client
            self._send_msg_to_client()

        logger.debug('session id: %s', self.session_id)

    def _init_session(self, application):
        """Create a new session for ``application`` under a fresh random 24-char id.

        Coroutine and generator functions run in a ``CoroutineBasedSession``;
        anything else runs in a ``ThreadBasedSession`` bound to ``self.ioloop``.
        """
        session_info = self.connection.make_session_info()
        self.session_id = random_str(24)
        if iscoroutinefunction(application) or isgeneratorfunction(application):
            self.session = CoroutineBasedSession(
                application, session_info=session_info,
                on_task_command=self._send_msg_to_client,
                on_session_close=self._close_from_session)
        else:
            self.session = ThreadBasedSession(
                application, session_info=session_info,
                on_task_command=self._send_msg_to_client,
                on_session_close=self._close_from_session,
                loop=self.ioloop)

    def _get_active_connection(self) -> Optional[WebSocketConnection]:
        """Return the connection that should receive this session's output, if any.

        With reconnection enabled this is the connection registered for the session
        id (which may belong to a later handler, or be None while detached);
        otherwise it is simply ``self.connection``.
        """
        if self.reconnectable:
            conn = _reconnect_state.active_connections.get(self.session_id)
        else:
            conn = self.connection
        return conn

    def _send_msg_to_client(self, session: Session = None):
        """Flush the pending task commands of ``session`` to the active connection.

        :param session: session to drain; defaults to ``self.session``.

        Does nothing when there is no open active connection (the commands are not
        fetched from the session in that case). Send errors are logged per message
        and the remaining messages are still attempted.
        """
        conn = self._get_active_connection()
        session = session or self.session
        if not conn or conn.closed():
            return
        for msg in session.get_task_commands():
            try:
                conn.write_message(msg)
            except TypeError as e:
                logger.exception('Data serialization error: %s\n'
                                 'This may be because you pass the wrong type of parameter to the function'
                                 ' of PyWebIO.\nData content: %s', e, msg)
            except Exception:
                logger.exception("Error in sending message via websocket")

    def _close_from_session(self):
        """Callback run by the session when it closes.

        With an open active connection, the remaining commands are flushed and the
        connection is closed. Otherwise, if reconnection is enabled, the remaining
        commands are stashed in ``session_will_messages`` for a later connection
        presenting this session id. Either way the handler drops its session reference.
        """
        conn = self._get_active_connection()
        if conn and not conn.closed():
            self._send_msg_to_client()
            conn.close()
        elif self.reconnectable:  # no active connection, and reconnect is enabled
            _reconnect_state.session_will_messages[self.session_id] = self.session.get_task_commands()
        self.session = None

    def send_client_data(self, data):
        """Decode a message from the client and forward it to the attached session.

        :param data: ``bytes`` are decoded with ``deserialize_binary_event``; anything
            else is parsed as JSON text. A decoded ``None`` is ignored.

        Events are dropped when no session is attached (e.g. the connection only
        received the expired-session handshake, or the session already closed).
        """
        if isinstance(data, bytes):
            event = deserialize_binary_event(data)
        else:
            event = json.loads(data)
        if event is None:
            return
        if self.session is None:
            # Nobody to deliver the event to; previously this raised AttributeError.
            return
        self.session.send_client_event(event)

    def notify_connection_lost(self):
        """Handle the loss of this handler's connection.

        Without reconnection the session is closed right away. With reconnection
        the session is kept but marked as detached, so ``clean_expired_sessions``
        can expire it. This only happens if this connection is still the session's
        active one. A client may already have reconnected on a newer connection
        before the loss of this one was noticed, and that newer connection must
        stay attached.
        """
        logger.debug("WebSocket closed")
        if not self.reconnectable and self.session:
            # when the connection lost is caused by `on_session_close()`, it's OK to close the session here though.
            # because the `session.close()` is reentrant
            self.session.close(nonblock=True)
            self.session = None  # reset the reference
            return

        if _reconnect_state.active_connections.get(self.session_id) is not self.connection:
            # Another connection has taken over this session, or this connection was
            # never (or is no longer) registered: leave the shared state alone.
            return
        del _reconnect_state.active_connections[self.session_id]
        if self.session_id in _reconnect_state.unclosed_sessions:
            _reconnect_state.detached_sessions[self.session_id] = time.time()
|