# ws.py (7.4 KB)
  1. import asyncio
  2. import json
  3. import logging
  4. import time
  5. import typing
  6. from typing import Dict
  7. import abc
  8. from ..utils import deserialize_binary_event
  9. from ...session import CoroutineBasedSession, ThreadBasedSession, Session
  10. from ...utils import iscoroutinefunction, isgeneratorfunction, \
  11. random_str, LRUDict
  # Module-level logger, named after this module's import path (__name__).
  logger = logging.getLogger(__name__)
  class _state:
      """Module-wide bookkeeping shared by every WebSocketHandler.

      Only class attributes are used; nothing in this file instantiates the
      class, it serves purely as a namespace.
      """
      # Grace period, in seconds, before a detached session is closed by
      # clean_expired_sessions(). set_expire_second() can only increase it.
      expire_second = 10

      # session_id -> Session for every session that is neither closed nor
      # expired. Filled for every new session (reconnection or not); used to
      # resume a session by id on a new connection and to find sessions to
      # close during cleanup.
      unclosed_sessions: Dict[str, Session] = {}

      # session_id -> the WebSocketConnection currently attached to the
      # session. Session callbacks go through this mapping to reach whichever
      # connection is live right now.
      active_connections: Dict[str, 'WebSocketConnection'] = {}

      # session_id -> task commands that could not be sent to the browser
      # because the session closed while no connection was attached.
      undelivered_messages: Dict[str, list] = {}

      # session_id -> time.time() at which the session lost its connection
      # (reconnection mode only). clean_expired_sessions() expects the oldest
      # entry at the head and stops at the first unexpired one.
      detached_sessions = LRUDict()
  def set_expire_second(sec):
      """Raise the detached-session grace period to ``sec`` seconds.

      The period never shrinks: a ``sec`` that is not larger than the current
      value leaves it unchanged.
      """
      if sec > _state.expire_second:
          _state.expire_second = sec
  def clean_expired_sessions():
      """Close and forget every detached session whose grace period has elapsed.

      Entries of ``_state.detached_sessions`` are taken from the head. The scan
      stops at the first session detached less than ``_state.expire_second``
      seconds ago; that entry is put back at the head. This assumes entries are
      ordered by detach time (oldest first).
      """
      while _state.detached_sessions:
          # Take the head entry, i.e. the earliest-detached session.
          session_id, detached_ts = _state.detached_sessions.popitem(last=False)
          if time.time() < detached_ts + _state.expire_second:
              # Not expired yet: restore it at the head and stop, since the
              # remaining entries were detached no earlier than this one.
              _state.detached_sessions[session_id] = detached_ts
              _state.detached_sessions.move_to_end(session_id, last=False)
              break

          # Expired: drop every trace of the session, then close it.
          # Lazy %-args: the message is only formatted if DEBUG is enabled.
          logger.debug("session %s expired", session_id)
          _state.active_connections.pop(session_id, None)
          _state.undelivered_messages.pop(session_id, None)
          session = _state.unclosed_sessions.pop(session_id, None)
          if session:
              session.close(nonblock=True)
  async def session_clean_task():
      """Run clean_expired_sessions() forever, pausing between passes.

      A failing pass is logged and does not stop the loop. The pause is
      ``_state.expire_second // 2`` seconds, re-read before every sleep.
      """
      logger.debug("Start session cleaning task")
      while True:
          try:
              clean_expired_sessions()
          except Exception:
              logger.exception("Error when clean expired sessions")
          pause = _state.expire_second // 2
          await asyncio.sleep(pause)
  class WebSocketConnection(abc.ABC):
      """Abstract transport that WebSocketHandler reads from and writes to.

      Concrete subclasses (not defined in this file) wrap an actual
      framework-specific WebSocket object.
      """

      @abc.abstractmethod
      def get_query_argument(self, name) -> typing.Optional[str]:
          """Return the query-string argument ``name``, or ``None``."""

      @abc.abstractmethod
      def make_session_info(self) -> dict:
          """Return the ``session_info`` dict handed to newly created sessions."""

      @abc.abstractmethod
      def write_message(self, message: dict):
          """Send ``message`` to the client."""

      @abc.abstractmethod
      def closed(self) -> bool:
          """Return whether the connection is closed (base version: ``False``)."""
          return False

      @abc.abstractmethod
      def close(self):
          """Close the connection."""
  65. @abc.abstractmethod
  66. def close(self):
  67. pass
  class WebSocketHandler:
      """Bridge between one WebSocket connection and a PyWebIO session.

      A handler is created per connection. With reconnection enabled a session
      can outlive its connection and be resumed by a later connection (and thus
      a later handler), but at most one connection is attached to a session at
      a time; the attached one is recorded in ``_state.active_connections``.
      """
      session_id: typing.Optional[str] = None
      session: typing.Optional[Session] = None  # the session this connection is attached to
      connection: WebSocketConnection
      reconnectable: bool

      def __init__(self, connection: WebSocketConnection, application, reconnectable: bool):
          """Attach ``connection`` to a new, resumed, or no-longer-existing session.

          :param connection: transport for this WebSocket.
          :param application: PyWebIO application; only used when a new session is created.
          :param reconnectable: whether the session may survive the loss of this connection.
          """
          logger.debug("WebSocket opened")
          self.connection = connection
          self.reconnectable = reconnectable
          self.session_id = connection.get_query_argument('session')

          if self.session_id in ('NEW', None):  # initial request: create a new session
              self._init_session(application)
              if reconnectable:
                  # Give the client its session id so it can send it back to resume
                  # the session after the connection is lost.
                  connection.write_message(dict(command='set_session_id', spec=self.session_id))
          elif self.session_id not in _state.unclosed_sessions:
              # The session is already closed or expired: flush what it left behind,
              # or just tell the client it is over. self.session stays None here.
              bye_msg = dict(command='close_session')
              for m in _state.undelivered_messages.get(self.session_id, [bye_msg]):
                  try:
                      connection.write_message(m)
                  except Exception:
                      logger.exception("Error in sending message via websocket")
          else:
              # Resume a live session on this connection.
              self.session = _state.unclosed_sessions[self.session_id]
              _state.detached_sessions.pop(self.session_id, None)
              _state.active_connections[self.session_id] = connection
              # send the latest messages to client
              self._send_msg_to_client(self.session)

          logger.debug('session id: %s', self.session_id)

      def _init_session(self, application):
          """Create a session running ``application`` and register it in ``_state``.

          Coroutine/generator functions get a CoroutineBasedSession; anything else
          gets a ThreadBasedSession bound to the current event loop.
          """
          session_info = self.connection.make_session_info()
          self.session_id = random_str(24)
          # _send_msg_to_client() finds the connection through this mapping, so it
          # is required whether or not reconnection is enabled.
          _state.active_connections[self.session_id] = self.connection

          if iscoroutinefunction(application) or isgeneratorfunction(application):
              self.session = CoroutineBasedSession(
                  application, session_info=session_info,
                  on_task_command=self._send_msg_to_client,
                  on_session_close=self._close_from_session)
          else:
              self.session = ThreadBasedSession(
                  application, session_info=session_info,
                  on_task_command=self._send_msg_to_client,
                  on_session_close=self._close_from_session,
                  loop=asyncio.get_event_loop())
          _state.unclosed_sessions[self.session_id] = self.session

      def _send_msg_to_client(self, session):
          """Forward the session's pending task commands to its live connection.

          Does nothing when no open connection is attached (the commands are not
          fetched from the session in that case).
          """
          # self.connection may not be the live one (the session may have been
          # resumed on a newer connection), so look up the active connection.
          conn = _state.active_connections.get(self.session_id)
          if not conn or conn.closed():
              return
          for msg in session.get_task_commands():
              try:
                  conn.write_message(msg)
              except TypeError as e:
                  logger.exception('Data serialization error: %s\n'
                                   'This may be because you pass the wrong type of parameter to the function'
                                   ' of PyWebIO.\nData content: %s', e, msg)
              except Exception:
                  logger.exception("Error in sending message via websocket")

      def _close_from_session(self):
          """Session-close callback: flush or stash final messages, then unregister.

          If a connection is attached, the remaining task commands are sent to it
          and it is closed. Otherwise, in reconnection mode, they are kept in
          ``_state.undelivered_messages`` for a client that comes back with this
          session id before clean_expired_sessions() drops the entry.
          """
          session = _state.unclosed_sessions.get(self.session_id)
          if session is None:
              # Already unregistered, e.g. clean_expired_sessions() removes the
              # session from _state before calling session.close(), or this is a
              # repeated close. There is no connection left and nothing to record.
              return

          if self.session_id in _state.active_connections:
              # send the undelivered messages to client
              self._send_msg_to_client(session=session)
          elif self.reconnectable:
              # Without reconnection no client is ever given this session id, so
              # nothing would read (or remove) a stashed entry.
              _state.undelivered_messages[self.session_id] = session.get_task_commands()

          conn = _state.active_connections.pop(self.session_id, None)
          _state.unclosed_sessions.pop(self.session_id, None)
          if conn and not conn.closed():
              conn.close()

      def send_client_data(self, data):
          """Decode one client message and pass the event to the attached session.

          :param data: ``bytes`` are decoded with deserialize_binary_event();
              anything else is parsed as JSON. A ``None`` event is ignored.
          """
          if self.session is None:
              # The client referred to a session that no longer exists (see
              # __init__); there is nothing to deliver the event to.
              logger.debug("Drop client message: session %s is not available", self.session_id)
              return
          if isinstance(data, bytes):
              event = deserialize_binary_event(data)
          else:
              event = json.loads(data)
          if event is None:
              return
          self.session.send_client_event(event)

      def notify_connection_lost(self):
          """Handle the closing of this handler's WebSocket.

          Without reconnection the session is closed. With reconnection the session
          is recorded as detached so clean_expired_sessions() can close it later,
          unless another connection has already taken it over.
          """
          # Unregister only our own connection: if the client has already
          # reconnected, the entry belongs to the newer connection and must stay.
          if _state.active_connections.get(self.session_id) is self.connection:
              del _state.active_connections[self.session_id]

          if not self.reconnectable:
              # When the connection loss is caused by `on_session_close()`, it's OK to
              # close the session here anyway, because `session.close()` is reentrant.
              if self.session is not None:
                  self.session.close(nonblock=True)
          elif (self.session_id in _state.unclosed_sessions
                and self.session_id not in _state.active_connections):
              _state.detached_sessions[self.session_id] = time.time()
          logger.debug("WebSocket closed")