ws.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import abc
  2. import asyncio
  3. import json
  4. import logging
  5. import time
  6. import typing
  7. from typing import Dict, Optional
  8. from ...session import CoroutineBasedSession, Session, ThreadBasedSession
  9. from ...utils import LRUDict, iscoroutinefunction, isgeneratorfunction, random_str
  10. from ..utils import deserialize_binary_event
  # Module-level logger named after this module, so its level/handlers can be
  # configured independently of the rest of the package.
  logger = logging.getLogger(name=__name__)
  class _state:
      """Module-wide registry shared by all WebSocket handlers.

      Used as a plain namespace: every attribute is a class attribute that is
      read and mutated as ``_state.<name>``.
      """
      # session_id -> timestamp at which the session lost its connection, kept in
      # increasing order of time (oldest first). Only filled when reconnection is
      # enabled; scanned by the expiry cleaner to find sessions to close.
      detached_sessions: LRUDict = LRUDict()

      # session_id -> session that has been neither closed nor expired.
      # Lets a new connection carrying a session id find its session again, and
      # lets the cleaner close sessions whose reconnect window has passed.
      unclosed_sessions: Dict[str, Session] = {}

      # session_id -> messages that could not be delivered to the browser because
      # the session closed while no connection was attached to it
      undelivered_messages: Dict[str, list] = {}

      # session_id -> connection currently attached to the session; session
      # callbacks use it to reach the live connection
      active_connections: Dict[str, 'WebSocketConnection'] = {}

      # reconnect timeout in seconds (the largest value passed to
      # set_expire_second); 0 means the session-cleaning task never runs
      expire_second: float = 0
  def set_expire_second(sec):
      """Record ``sec`` (seconds) as the session expiry time unless a larger one is set.

      The stored value only ever grows, so the longest timeout requested wins.
      """
      if sec > _state.expire_second:
          _state.expire_second = sec
  def clean_expired_sessions():
      """Close and forget every detached session whose reconnect window has passed.

      ``_state.detached_sessions`` is ordered by detach time (oldest first), so the
      scan stops at the first session that has not expired yet. For each expired
      session its active connection, undelivered messages and registry entry are
      dropped, then the session itself is closed without blocking.
      """
      while _state.detached_sessions:
          # pop the session that was detached the earliest
          session_id, detached_ts = _state.detached_sessions.popitem(last=False)
          if time.time() < detached_ts + _state.expire_second:
              # not expired yet: put it back at the head and stop scanning,
              # every later entry was detached even more recently
              _state.detached_sessions[session_id] = detached_ts
              _state.detached_sessions.move_to_end(session_id, last=False)
              break

          # clean up this expired session (lazy %-args: formatted only if DEBUG is on)
          logger.debug("session %s expired", session_id)
          _state.active_connections.pop(session_id, None)
          _state.undelivered_messages.pop(session_id, None)
          session = _state.unclosed_sessions.pop(session_id, None)
          if session:
              session.close(nonblock=True)
  # Set once the cleaning loop below has been started, so it only ever runs once.
  _session_clean_task_started = False


  async def session_clean_task():
      """Periodically close expired detached sessions; loops forever once started.

      Returns immediately if a cleaning task is already running or if session
      expiry is disabled (``_state.expire_second`` is falsy). Errors raised while
      cleaning are logged and the loop keeps going.
      """
      global _session_clean_task_started
      if _session_clean_task_started or not _state.expire_second:
          return
      _session_clean_task_started = True
      logger.debug("Start session cleaning task")
      while True:
          try:
              clean_expired_sessions()
          except Exception:
              logger.exception("Error when clean expired sessions")
          # True division: with floor division an expire_second below 2 (e.g. 1 or
          # 0.5) gave a 0-second sleep and turned this loop into a busy spin.
          await asyncio.sleep(_state.expire_second / 2)
  class WebSocketConnection(abc.ABC):
      """Interface a concrete WebSocket connection must provide to WebSocketHandler.

      Every method is abstract; subclasses wrap one browser connection.
      """

      @abc.abstractmethod
      def get_query_argument(self, name) -> typing.Optional[str]:
          """Return the query-string argument ``name`` (``None`` presumably meaning absent)."""

      @abc.abstractmethod
      def make_session_info(self) -> dict:
          """Build the ``session_info`` dict handed to a newly created session."""

      @abc.abstractmethod
      def write_message(self, message: dict):
          """Send one message dict to the client."""

      @abc.abstractmethod
      def closed(self) -> bool:
          """Report whether the connection is closed (base implementation: never)."""
          return False

      @abc.abstractmethod
      def close(self):
          """Close the underlying connection."""
  class WebSocketHandler:
      """
      Held by one connection.

      A session may be shared by several connections over its lifetime (when
      reconnection is enabled), but only one connection is attached at a time.
      """
      session_id: Optional[str] = None
      session: Optional[Session] = None  # the session the current connection is attached to
      connection: WebSocketConnection
      reconnectable: bool

      def __init__(self, connection: WebSocketConnection, application, reconnectable: bool, ioloop=None):
          """Attach ``connection`` to a new, an expired, or an existing session.

          The ``session`` query argument selects the case:

          * ``'NEW'`` or missing: start a new session running ``application``; when
            ``reconnectable``, the generated id is sent to the client.
          * an id not in ``_state.unclosed_sessions``: the session is gone; send the
            client its undelivered messages (or a ``close_session`` command) and
            leave ``self.session`` as ``None``.
          * the id of an unclosed session: re-attach to it and flush its pending
            messages to this connection.
          """
          logger.debug("WebSocket opened")
          self.connection = connection
          self.reconnectable = reconnectable
          self.session_id = connection.get_query_argument('session')
          self.ioloop = ioloop or asyncio.get_event_loop()

          if self.session_id in ('NEW', None):  # initial request: create a new session
              self._init_session(application)
              if reconnectable:
                  # set session id to client, so the client can send it back to server to recover a session when it
                  # resumes from a connection lost
                  connection.write_message(dict(command='set_session_id', spec=self.session_id))
          elif self.session_id not in _state.unclosed_sessions:  # session is expired
              bye_msg = dict(command='close_session')
              for m in _state.undelivered_messages.get(self.session_id, [bye_msg]):
                  try:
                      connection.write_message(m)
                  except Exception:
                      logger.exception("Error in sending message via websocket")
          else:
              self.session = _state.unclosed_sessions[self.session_id]
              _state.detached_sessions.pop(self.session_id, None)
              _state.active_connections[self.session_id] = connection
              # send the latest messages to client
              self._send_msg_to_client(self.session)

          logger.debug('session id: %s', self.session_id)

      def _init_session(self, application):
          """Create and register a new session for ``application`` under a fresh id.

          Coroutine/generator functions get a CoroutineBasedSession, anything else a
          ThreadBasedSession bound to ``self.ioloop``. This connection becomes the
          session's active connection.
          """
          session_info = self.connection.make_session_info()
          self.session_id = random_str(24)
          # todo: only set item when reconnection enabled
          _state.active_connections[self.session_id] = self.connection
          if iscoroutinefunction(application) or isgeneratorfunction(application):
              self.session = CoroutineBasedSession(
                  application, session_info=session_info,
                  on_task_command=self._send_msg_to_client,
                  on_session_close=self._close_from_session)
          else:
              self.session = ThreadBasedSession(
                  application, session_info=session_info,
                  on_task_command=self._send_msg_to_client,
                  on_session_close=self._close_from_session,
                  loop=self.ioloop)
          _state.unclosed_sessions[self.session_id] = self.session

      def _send_msg_to_client(self, session):
          """Send ``session``'s pending task commands over its active connection.

          If no open connection is attached, nothing is sent and the commands stay
          queued in the session. Per-message send errors are logged, not raised.
          """
          # self.connection may not be active,
          # here we need the active connection for this session
          conn = _state.active_connections.get(self.session_id)
          if not conn or conn.closed():
              return
          for msg in session.get_task_commands():
              try:
                  conn.write_message(msg)
              except TypeError as e:
                  logger.exception('Data serialization error: %s\n'
                                   'This may be because you pass the wrong type of parameter to the function'
                                   ' of PyWebIO.\nData content: %s', e, msg)
              except Exception:
                  logger.exception("Error in sending message via websocket")

      def _close_from_session(self):
          """Session-close callback: flush or stash pending messages, then drop the session.

          With a connection attached, the remaining messages are sent to it;
          otherwise they are stored in ``_state.undelivered_messages`` so a client
          reconnecting with this id can still receive them. The session is then
          unregistered and the active connection, if still open, is closed.
          """
          session = _state.unclosed_sessions.get(self.session_id)
          if session is None:
              # Already unregistered: e.g. clean_expired_sessions() pops the session
              # (and its connection / undelivered messages) before calling
              # session.close(). Indexing here used to raise KeyError in that case.
              return
          if self.session_id in _state.active_connections:
              # send the undelivered messages to client
              self._send_msg_to_client(session=session)
          else:
              _state.undelivered_messages[self.session_id] = session.get_task_commands()
          conn = _state.active_connections.pop(self.session_id, None)
          _state.unclosed_sessions.pop(self.session_id, None)
          if conn and not conn.closed():
              conn.close()

      def send_client_data(self, data):
          """Decode one client message and forward the event to the attached session.

          ``bytes`` payloads are decoded with ``deserialize_binary_event``; anything
          else is parsed as JSON. A ``None`` event is ignored.
          """
          if self.session is None:
              # The session had already expired when this connection opened
              # (see __init__); there is nothing to deliver the event to.
              return
          if isinstance(data, bytes):
              event = deserialize_binary_event(data)
          else:
              event = json.loads(data)
          if event is None:
              return
          self.session.send_client_event(event)

      def notify_connection_lost(self):
          """Handle this connection closing.

          Without reconnection the session is closed. With reconnection the session
          is marked detached (timestamped) so it can be resumed or expire later.
          """
          # Only detach if this handler's connection is still the session's active
          # one: a reconnecting client may already have attached a newer connection
          # to the same session, which must be neither dropped nor marked detached.
          is_active = _state.active_connections.get(self.session_id) is self.connection
          if is_active:
              _state.active_connections.pop(self.session_id, None)

          if not self.reconnectable:
              # when the connection lost is caused by `on_session_close()`, it's OK to close the session here though,
              # because `session.close()` is reentrant
              if self.session is not None:
                  self.session.close(nonblock=True)
          elif is_active and self.session_id in _state.unclosed_sessions:
              _state.detached_sessions[self.session_id] = time.time()
          logger.debug("WebSocket closed")