ws.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import abc
  2. import asyncio
  3. import json
  4. import logging
  5. import time
  6. import typing
  7. from typing import Dict, Optional
  8. from ...session import CoroutineBasedSession, Session, ThreadBasedSession
  9. from ...utils import LRUDict, iscoroutinefunction, isgeneratorfunction, random_str
  10. from ..utils import deserialize_binary_event
  11. logger = logging.getLogger(__name__)
  12. # used to store global state when reconnect enabled
  13. class _reconnect_state:
  14. # used to clean up session
  15. detached_sessions = LRUDict() # session_id -> detached timestamp. In increasing order of the time
  16. # unclosed and unexpired session
  17. # used to clean up session
  18. # used to retrieve session by id when new connection
  19. unclosed_sessions: Dict[str, Session] = {} # session_id -> session
  20. # the messages that can't deliver to browser when session close due to connection lost
  21. session_will_messages: Dict[str, list] = {} # session_id -> unhandled message list
  22. # used to get the active conn in session's callbacks
  23. active_connections: Dict[str, 'WebSocketConnection'] = {} # session_id -> WSHandler
  24. expire_second = 0
  25. def set_expire_second(sec):
  26. _reconnect_state.expire_second = max(_reconnect_state.expire_second, sec)
  27. def clean_expired_sessions():
  28. while _reconnect_state.detached_sessions:
  29. session_id, detached_ts = _reconnect_state.detached_sessions.popitem(last=False) # 弹出最早过期的session
  30. if time.time() < detached_ts + _reconnect_state.expire_second:
  31. # this session is not expired
  32. _reconnect_state.detached_sessions[session_id] = detached_ts # restore
  33. _reconnect_state.detached_sessions.move_to_end(session_id, last=False) # move to head
  34. break
  35. # clean this session
  36. logger.debug("session %s expired" % session_id)
  37. _reconnect_state.active_connections.pop(session_id, None)
  38. _reconnect_state.session_will_messages.pop(session_id, None)
  39. session = _reconnect_state.unclosed_sessions.pop(session_id, None)
  40. if session:
  41. session.close(nonblock=True)
  42. _session_clean_task_started = False
  43. async def session_clean_task():
  44. global _session_clean_task_started
  45. if _session_clean_task_started or not _reconnect_state.expire_second:
  46. return
  47. _session_clean_task_started = True
  48. logger.debug("Start session cleaning task")
  49. while True:
  50. try:
  51. clean_expired_sessions()
  52. except Exception:
  53. logger.exception("Error when clean expired sessions")
  54. await asyncio.sleep(_reconnect_state.expire_second // 2)
  55. class WebSocketConnection(abc.ABC):
  56. @abc.abstractmethod
  57. def get_query_argument(self, name) -> typing.Optional[str]:
  58. pass
  59. @abc.abstractmethod
  60. def make_session_info(self) -> dict:
  61. pass
  62. @abc.abstractmethod
  63. def write_message(self, message: dict):
  64. pass
  65. @abc.abstractmethod
  66. def closed(self) -> bool:
  67. return False
  68. @abc.abstractmethod
  69. def close(self):
  70. pass
  71. class WebSocketHandler:
  72. """
  73. hold by one connection,
  74. share one session with multiple connection in session lifetime, but one conn at a time
  75. """
  76. session_id: Optional[str] = None
  77. session: Optional[Session] = None # the session that current connection attaches
  78. connection: WebSocketConnection
  79. reconnectable: bool
  80. def __init__(self, connection: WebSocketConnection, application, reconnectable: bool, ioloop=None):
  81. logger.debug("WebSocket opened")
  82. self.connection = connection
  83. self.reconnectable = reconnectable
  84. self.session_id = connection.get_query_argument('session')
  85. self.ioloop = ioloop or asyncio.get_event_loop()
  86. if self.session_id in ('NEW', None): # 初始请求,创建新 Session
  87. self._init_session(application)
  88. if reconnectable:
  89. _reconnect_state.active_connections[self.session_id] = self.connection
  90. _reconnect_state.unclosed_sessions[self.session_id] = self.session
  91. # set session id to client, so the client can send it back to server to recover a session when it
  92. # resumes form a connection lost
  93. connection.write_message(dict(command='set_session_id', spec=self.session_id))
  94. elif self.session_id not in _reconnect_state.unclosed_sessions: # session is expired
  95. bye_msg = dict(command='close_session')
  96. for m in _reconnect_state.session_will_messages.get(self.session_id, [bye_msg]):
  97. try:
  98. connection.write_message(m)
  99. except Exception:
  100. logger.exception("Error in sending message via websocket")
  101. else: # resumes form a connection lost
  102. self.session = _reconnect_state.unclosed_sessions[self.session_id]
  103. _reconnect_state.detached_sessions.pop(self.session_id, None)
  104. _reconnect_state.active_connections[self.session_id] = connection
  105. # send the latest messages to client
  106. self._send_msg_to_client()
  107. logger.debug('session id: %s' % self.session_id)
  108. def _init_session(self, application):
  109. session_info = self.connection.make_session_info()
  110. self.session_id = random_str(24)
  111. if iscoroutinefunction(application) or isgeneratorfunction(application):
  112. self.session = CoroutineBasedSession(
  113. application, session_info=session_info,
  114. on_task_command=self._send_msg_to_client,
  115. on_session_close=self._close_from_session)
  116. else:
  117. self.session = ThreadBasedSession(
  118. application, session_info=session_info,
  119. on_task_command=self._send_msg_to_client,
  120. on_session_close=self._close_from_session,
  121. loop=self.ioloop)
  122. def _get_active_connection(self) -> Optional[WebSocketConnection]:
  123. # when reconnect enabled, the active connection for this session is in _reconnect_state.active_connections,
  124. # otherwise, it's self.connection.
  125. if self.reconnectable:
  126. conn = _reconnect_state.active_connections.get(self.session_id)
  127. else:
  128. conn = self.connection
  129. return conn
  130. def _send_msg_to_client(self, session: Session = None):
  131. conn = self._get_active_connection()
  132. session = session or self.session
  133. if not conn or conn.closed():
  134. return
  135. for msg in session.get_task_commands():
  136. try:
  137. conn.write_message(msg)
  138. except TypeError as e:
  139. logger.exception('Data serialization error: %s\n'
  140. 'This may be because you pass the wrong type of parameter to the function'
  141. ' of PyWebIO.\nData content: %s', e, msg)
  142. except Exception:
  143. logger.exception("Error in sending message via websocket")
  144. def _close_from_session(self):
  145. conn = self._get_active_connection()
  146. if conn and not conn.closed():
  147. self._send_msg_to_client()
  148. conn.close()
  149. elif self.reconnectable: # no active connection, and reconnect is enabled
  150. _reconnect_state.session_will_messages[self.session_id] = self.session.get_task_commands()
  151. self.session = None
  152. def send_client_data(self, data):
  153. if isinstance(data, bytes):
  154. event = deserialize_binary_event(data)
  155. else:
  156. event = json.loads(data)
  157. if event is None:
  158. return
  159. self.session.send_client_event(event)
  160. def notify_connection_lost(self):
  161. logger.debug("WebSocket closed")
  162. if not self.reconnectable and self.session:
  163. # when the connection lost is caused by `on_session_close()`, it's OK to close the session here though.
  164. # because the `session.close()` is reentrant
  165. self.session.close(nonblock=True)
  166. self.session = None # reset the reference
  167. return
  168. _reconnect_state.active_connections.pop(self.session_id, None)
  169. if self.session_id in _reconnect_state.unclosed_sessions:
  170. _reconnect_state.detached_sessions[self.session_id] = time.time()