ws.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import asyncio
  2. import json
  3. import logging
  4. import time
  5. import typing
  6. from typing import Dict
  7. import abc
  8. from ..utils import deserialize_binary_event
  9. from ...session import CoroutineBasedSession, ThreadBasedSession, Session
  10. from ...utils import iscoroutinefunction, isgeneratorfunction, \
  11. random_str, LRUDict
  12. logger = logging.getLogger(__name__)
  13. class _state:
  14. # only used in reconnect enabled
  15. # used to clean up session
  16. detached_sessions = LRUDict() # session_id -> detached timestamp. In increasing order of the time
  17. # unclosed and unexpired session
  18. # only used in reconnect enabled
  19. # used to clean up session
  20. # used to retrieve session by id when new connection
  21. unclosed_sessions: Dict[str, Session] = {} # session_id -> session
  22. # the messages that can't deliver to browser when session close due to connection lost
  23. undelivered_messages: Dict[str, list] = {} # session_id -> unhandled message list
  24. # used to get the active conn in session's callbacks
  25. active_connections: Dict[str, 'WebSocketConnection'] = {} # session_id -> WSHandler
  26. expire_second = 0
  27. def set_expire_second(sec):
  28. _state.expire_second = max(_state.expire_second, sec)
  29. def clean_expired_sessions():
  30. while _state.detached_sessions:
  31. session_id, detached_ts = _state.detached_sessions.popitem(last=False) # 弹出最早过期的session
  32. if time.time() < detached_ts + _state.expire_second:
  33. # this session is not expired
  34. _state.detached_sessions[session_id] = detached_ts # restore
  35. _state.detached_sessions.move_to_end(session_id, last=False) # move to head
  36. break
  37. # clean this session
  38. logger.debug("session %s expired" % session_id)
  39. _state.active_connections.pop(session_id, None)
  40. _state.undelivered_messages.pop(session_id, None)
  41. session = _state.unclosed_sessions.pop(session_id, None)
  42. if session:
  43. session.close(nonblock=True)
  44. _session_clean_task_started = False
  45. async def session_clean_task():
  46. global _session_clean_task_started
  47. if _session_clean_task_started or not _state.expire_second:
  48. return
  49. _session_clean_task_started = True
  50. logger.debug("Start session cleaning task")
  51. while True:
  52. try:
  53. clean_expired_sessions()
  54. except Exception:
  55. logger.exception("Error when clean expired sessions")
  56. await asyncio.sleep(_state.expire_second // 2)
  57. class WebSocketConnection(abc.ABC):
  58. @abc.abstractmethod
  59. def get_query_argument(self, name) -> typing.Optional[str]:
  60. pass
  61. @abc.abstractmethod
  62. def make_session_info(self) -> dict:
  63. pass
  64. @abc.abstractmethod
  65. def write_message(self, message: dict):
  66. pass
  67. @abc.abstractmethod
  68. def closed(self) -> bool:
  69. return False
  70. @abc.abstractmethod
  71. def close(self):
  72. pass
  73. class WebSocketHandler:
  74. """
  75. hold by one connection,
  76. share one session with multiple connection in session lifetime, but one conn at a time
  77. """
  78. session_id: str = None
  79. session: Session = None # the session that current connection attaches
  80. connection: WebSocketConnection
  81. reconnectable: bool
  82. def __init__(self, connection: WebSocketConnection, application, reconnectable: bool, ioloop=None):
  83. logger.debug("WebSocket opened")
  84. self.connection = connection
  85. self.reconnectable = reconnectable
  86. self.session_id = connection.get_query_argument('session')
  87. self.ioloop = ioloop or asyncio.get_event_loop()
  88. if self.session_id in ('NEW', None): # 初始请求,创建新 Session
  89. self._init_session(application)
  90. if reconnectable:
  91. # set session id to client, so the client can send it back to server to recover a session when it
  92. # resumes form a connection lost
  93. connection.write_message(dict(command='set_session_id', spec=self.session_id))
  94. elif self.session_id not in _state.unclosed_sessions: # session is expired
  95. bye_msg = dict(command='close_session')
  96. for m in _state.undelivered_messages.get(self.session_id, [bye_msg]):
  97. try:
  98. connection.write_message(m)
  99. except Exception:
  100. logger.exception("Error in sending message via websocket")
  101. else:
  102. self.session = _state.unclosed_sessions[self.session_id]
  103. _state.detached_sessions.pop(self.session_id, None)
  104. _state.active_connections[self.session_id] = connection
  105. # send the latest messages to client
  106. self._send_msg_to_client(self.session)
  107. logger.debug('session id: %s' % self.session_id)
  108. def _init_session(self, application):
  109. session_info = self.connection.make_session_info()
  110. self.session_id = random_str(24)
  111. # todo: only set item when reconnection enabled
  112. _state.active_connections[self.session_id] = self.connection
  113. if iscoroutinefunction(application) or isgeneratorfunction(application):
  114. self.session = CoroutineBasedSession(
  115. application, session_info=session_info,
  116. on_task_command=self._send_msg_to_client,
  117. on_session_close=self._close_from_session)
  118. else:
  119. self.session = ThreadBasedSession(
  120. application, session_info=session_info,
  121. on_task_command=self._send_msg_to_client,
  122. on_session_close=self._close_from_session,
  123. loop=self.ioloop)
  124. _state.unclosed_sessions[self.session_id] = self.session
  125. def _send_msg_to_client(self, session):
  126. # self.connection may not be active,
  127. # here we need the active connection for this session
  128. conn = _state.active_connections.get(self.session_id)
  129. if not conn or conn.closed():
  130. return
  131. for msg in session.get_task_commands():
  132. try:
  133. conn.write_message(msg)
  134. except TypeError as e:
  135. logger.exception('Data serialization error: %s\n'
  136. 'This may be because you pass the wrong type of parameter to the function'
  137. ' of PyWebIO.\nData content: %s', e, msg)
  138. except Exception:
  139. logger.exception("Error in sending message via websocket")
  140. def _close_from_session(self):
  141. session = _state.unclosed_sessions[self.session_id]
  142. if self.session_id in _state.active_connections:
  143. # send the undelivered messages to client
  144. self._send_msg_to_client(session=session)
  145. else:
  146. _state.undelivered_messages[self.session_id] = session.get_task_commands()
  147. conn = _state.active_connections.pop(self.session_id, None)
  148. _state.unclosed_sessions.pop(self.session_id, None)
  149. if conn and not conn.closed():
  150. conn.close()
  151. def send_client_data(self, data):
  152. if isinstance(data, bytes):
  153. event = deserialize_binary_event(data)
  154. else:
  155. event = json.loads(data)
  156. if event is None:
  157. return
  158. self.session.send_client_event(event)
  159. def notify_connection_lost(self):
  160. _state.active_connections.pop(self.session_id, None)
  161. if not self.reconnectable:
  162. # when the connection lost is caused by `on_session_close()`, it's OK to close the session here though.
  163. # because the `session.close()` is reentrant
  164. self.session.close(nonblock=True)
  165. else:
  166. if self.session_id in _state.unclosed_sessions:
  167. _state.detached_sessions[self.session_id] = time.time()
  168. logger.debug("WebSocket closed")