ws.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import asyncio
  2. import json
  3. import logging
  4. import time
  5. import typing
  6. from typing import Dict
  7. import abc
  8. from ..utils import deserialize_binary_event
  9. from ...session import CoroutineBasedSession, ThreadBasedSession, Session
  10. from ...utils import iscoroutinefunction, isgeneratorfunction, \
  11. random_str, LRUDict
  12. logger = logging.getLogger(__name__)
  13. class _state:
  14. # only used in reconnect enabled
  15. # used to clean up session
  16. detached_sessions = LRUDict() # session_id -> detached timestamp. In increasing order of the time
  17. # unclosed and unexpired session
  18. # only used in reconnect enabled
  19. # used to clean up session
  20. # used to retrieve session by id when new connection
  21. unclosed_sessions: Dict[str, Session] = {} # session_id -> session
  22. # the messages that can't deliver to browser when session close due to connection lost
  23. undelivered_messages: Dict[str, list] = {} # session_id -> unhandled message list
  24. # used to get the active conn in session's callbacks
  25. active_connections: Dict[str, 'WebSocketConnection'] = {} # session_id -> WSHandler
  26. expire_second = 10
  27. def set_expire_second(sec):
  28. _state.expire_second = max(_state.expire_second, sec)
  29. def clean_expired_sessions():
  30. while _state.detached_sessions:
  31. session_id, detached_ts = _state.detached_sessions.popitem(last=False) # 弹出最早过期的session
  32. if time.time() < detached_ts + _state.expire_second:
  33. # this session is not expired
  34. _state.detached_sessions[session_id] = detached_ts # restore
  35. _state.detached_sessions.move_to_end(session_id, last=False) # move to head
  36. break
  37. # clean this session
  38. logger.debug("session %s expired" % session_id)
  39. _state.active_connections.pop(session_id, None)
  40. _state.undelivered_messages.pop(session_id, None)
  41. session = _state.unclosed_sessions.pop(session_id, None)
  42. if session:
  43. session.close(nonblock=True)
  44. async def session_clean_task():
  45. logger.debug("Start session cleaning task")
  46. while True:
  47. try:
  48. clean_expired_sessions()
  49. except Exception:
  50. logger.exception("Error when clean expired sessions")
  51. await asyncio.sleep(_state.expire_second // 2)
  52. class WebSocketConnection(abc.ABC):
  53. @abc.abstractmethod
  54. def get_query_argument(self, name) -> typing.Optional[str]:
  55. pass
  56. @abc.abstractmethod
  57. def make_session_info(self) -> dict:
  58. pass
  59. @abc.abstractmethod
  60. def write_message(self, message: dict):
  61. pass
  62. @abc.abstractmethod
  63. def closed(self) -> bool:
  64. return False
  65. @abc.abstractmethod
  66. def close(self):
  67. pass
  68. class WebSocketHandler:
  69. """
  70. hold by one connection,
  71. share one session with multiple connection in session lifetime, but one conn at a time
  72. """
  73. session_id: str = None
  74. session: Session = None # the session that current connection attaches
  75. connection: WebSocketConnection
  76. reconnectable: bool
  77. def __init__(self, connection: WebSocketConnection, application, reconnectable: bool, ioloop=None):
  78. logger.debug("WebSocket opened")
  79. self.connection = connection
  80. self.reconnectable = reconnectable
  81. self.session_id = connection.get_query_argument('session')
  82. self.ioloop = ioloop or asyncio.get_event_loop()
  83. if self.session_id in ('NEW', None): # 初始请求,创建新 Session
  84. self._init_session(application)
  85. if reconnectable:
  86. # set session id to client, so the client can send it back to server to recover a session when it
  87. # resumes form a connection lost
  88. connection.write_message(dict(command='set_session_id', spec=self.session_id))
  89. elif self.session_id not in _state.unclosed_sessions: # session is expired
  90. bye_msg = dict(command='close_session')
  91. for m in _state.undelivered_messages.get(self.session_id, [bye_msg]):
  92. try:
  93. connection.write_message(m)
  94. except Exception:
  95. logger.exception("Error in sending message via websocket")
  96. else:
  97. self.session = _state.unclosed_sessions[self.session_id]
  98. _state.detached_sessions.pop(self.session_id, None)
  99. _state.active_connections[self.session_id] = connection
  100. # send the latest messages to client
  101. self._send_msg_to_client(self.session)
  102. logger.debug('session id: %s' % self.session_id)
  103. def _init_session(self, application):
  104. session_info = self.connection.make_session_info()
  105. self.session_id = random_str(24)
  106. # todo: only set item when reconnection enabled
  107. _state.active_connections[self.session_id] = self.connection
  108. if iscoroutinefunction(application) or isgeneratorfunction(application):
  109. self.session = CoroutineBasedSession(
  110. application, session_info=session_info,
  111. on_task_command=self._send_msg_to_client,
  112. on_session_close=self._close_from_session)
  113. else:
  114. self.session = ThreadBasedSession(
  115. application, session_info=session_info,
  116. on_task_command=self._send_msg_to_client,
  117. on_session_close=self._close_from_session,
  118. loop=self.ioloop)
  119. _state.unclosed_sessions[self.session_id] = self.session
  120. def _send_msg_to_client(self, session):
  121. # self.connection may not be active,
  122. # here we need the active connection for this session
  123. conn = _state.active_connections.get(self.session_id)
  124. if not conn or conn.closed():
  125. return
  126. for msg in session.get_task_commands():
  127. try:
  128. conn.write_message(msg)
  129. except TypeError as e:
  130. logger.exception('Data serialization error: %s\n'
  131. 'This may be because you pass the wrong type of parameter to the function'
  132. ' of PyWebIO.\nData content: %s', e, msg)
  133. except Exception:
  134. logger.exception("Error in sending message via websocket")
  135. def _close_from_session(self):
  136. session = _state.unclosed_sessions[self.session_id]
  137. if self.session_id in _state.active_connections:
  138. # send the undelivered messages to client
  139. self._send_msg_to_client(session=session)
  140. else:
  141. _state.undelivered_messages[self.session_id] = session.get_task_commands()
  142. conn = _state.active_connections.pop(self.session_id, None)
  143. _state.unclosed_sessions.pop(self.session_id, None)
  144. if conn and not conn.closed():
  145. conn.close()
  146. def send_client_data(self, data):
  147. if isinstance(data, bytes):
  148. event = deserialize_binary_event(data)
  149. else:
  150. event = json.loads(data)
  151. if event is None:
  152. return
  153. self.session.send_client_event(event)
  154. def notify_connection_lost(self):
  155. _state.active_connections.pop(self.session_id, None)
  156. if not self.reconnectable:
  157. # when the connection lost is caused by `on_session_close()`, it's OK to close the session here though.
  158. # because the `session.close()` is reentrant
  159. self.session.close(nonblock=True)
  160. else:
  161. if self.session_id in _state.unclosed_sessions:
  162. _state.detached_sessions[self.session_id] = time.time()
  163. logger.debug("WebSocket closed")