|
@@ -14,12 +14,13 @@ import logging
|
|
|
import threading
|
|
|
import time
|
|
|
from contextlib import contextmanager
|
|
|
-from typing import Dict, Optional
|
|
|
+from typing import Dict, Optional, List
|
|
|
+from collections import deque
|
|
|
|
|
|
from ..page import make_applications, render_page
|
|
|
from ..utils import deserialize_binary_event
|
|
|
from ...session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target
|
|
|
-from ...session.base import get_session_info_from_headers
|
|
|
+from ...session.base import get_session_info_from_headers, Session
|
|
|
from ...utils import random_str, LRUDict, isgeneratorfunction, iscoroutinefunction, check_webio_js
|
|
|
|
|
|
|
|
@@ -35,7 +36,7 @@ class HttpContext:
|
|
|
Return the current request object"""
|
|
|
pass
|
|
|
|
|
|
- def request_method(self):
|
|
|
+ def request_method(self) -> str:
|
|
|
"""返回当前请求的方法,大写
|
|
|
Return the HTTP method of the current request, uppercase"""
|
|
|
pass
|
|
@@ -45,12 +46,12 @@ class HttpContext:
|
|
|
Return the header dictionary of the current request"""
|
|
|
pass
|
|
|
|
|
|
- def request_url_parameter(self, name, default=None):
|
|
|
+ def request_url_parameter(self, name, default=None) -> str:
|
|
|
"""返回当前请求的URL参数
|
|
|
Returns the value of the given URL parameter of the current request"""
|
|
|
pass
|
|
|
|
|
|
- def request_body(self):
|
|
|
+ def request_body(self) -> bytes:
|
|
|
"""返回当前请求的body数据
|
|
|
Returns the data of the current request body
|
|
|
|
|
@@ -58,16 +59,6 @@ class HttpContext:
|
|
|
"""
|
|
|
return b''
|
|
|
|
|
|
- def request_json(self) -> Optional[Dict]:
|
|
|
- """返回当前请求的json反序列化后的内容,若请求数据不为json格式,返回None
|
|
|
- Return the data (json deserialization) of the currently requested, if the data is not in json format, return None"""
|
|
|
- try:
|
|
|
- if self.request_headers().get('content-type') == 'application/octet-stream':
|
|
|
- return deserialize_binary_event(self.request_body())
|
|
|
- return json.loads(self.request_body())
|
|
|
- except Exception:
|
|
|
- return None
|
|
|
-
|
|
|
def set_header(self, name, value):
|
|
|
"""为当前响应设置header
|
|
|
Set a header for the current response"""
|
|
@@ -92,7 +83,7 @@ class HttpContext:
|
|
|
Get the current response object"""
|
|
|
pass
|
|
|
|
|
|
- def get_client_ip(self):
|
|
|
+ def get_client_ip(self) -> str:
|
|
|
"""获取用户的ip
|
|
|
Get the user's ip"""
|
|
|
pass
|
|
@@ -102,6 +93,56 @@ logger = logging.getLogger(__name__)
|
|
|
_event_loop = None
|
|
|
|
|
|
|
|
|
+class ReliableTransport:
|
|
|
+ def __init__(self, session: Session, message_window: int = 4):
|
|
|
+ self.session = session
|
|
|
+ self.messages = deque()
|
|
|
+ self.window_size = message_window
|
|
|
+ self.min_msg_id = 0 # the id of the first message in the window
|
|
|
+ self.finished_event_id = -1 # the id of the last finished event
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def close_message(ack):
|
|
|
+ return dict(
|
|
|
+ commands=[[dict(command='close_session')]],
|
|
|
+ seq=ack + 1
|
|
|
+ )
|
|
|
+
|
|
|
+ def push_event(self, events: List[Dict], seq: int) -> int:
|
|
|
+ """Send client events to the session and return the success message count"""
|
|
|
+ if not events:
|
|
|
+ return 0
|
|
|
+
|
|
|
+ submit_cnt = 0
|
|
|
+ for eid, event in enumerate(events, start=seq):
|
|
|
+ if eid > self.finished_event_id:
|
|
|
+ self.finished_event_id = eid # todo: use lock for check and set operation
|
|
|
+ self.session.send_client_event(event)
|
|
|
+ submit_cnt += 1
|
|
|
+
|
|
|
+ return submit_cnt
|
|
|
+
|
|
|
+ def get_response(self, ack=0):
|
|
|
+ """
|
|
|
+ ack num is the number of messages that the client has received.
|
|
|
+ response is a list of messages that the client should receive, along with their min id `seq`.
|
|
|
+ """
|
|
|
+ while ack >= self.min_msg_id and self.messages:
|
|
|
+ self.messages.popleft()
|
|
|
+ self.min_msg_id += 1
|
|
|
+
|
|
|
+ if len(self.messages) < self.window_size:
|
|
|
+ msgs = self.session.get_task_commands()
|
|
|
+ if msgs:
|
|
|
+ self.messages.append(msgs)
|
|
|
+
|
|
|
+ return dict(
|
|
|
+ commands=list(self.messages),
|
|
|
+ seq=self.min_msg_id,
|
|
|
+ ack=self.finished_event_id
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
# todo: use lock to avoid thread race condition
|
|
|
class HttpHandler:
|
|
|
"""基于HTTP的后端Handler实现
|
|
@@ -112,7 +153,7 @@ class HttpHandler:
|
|
|
|
|
|
"""
|
|
|
_webio_sessions = {} # WebIOSessionID -> WebIOSession()
|
|
|
- _webio_last_commands = {} # WebIOSessionID -> (last commands, commands sequence id)
|
|
|
+ _webio_transports = {} # WebIOSessionID -> ReliableTransport(), type: Dict[str, ReliableTransport]
|
|
|
_webio_expire = LRUDict() # WebIOSessionID -> last active timestamp. In increasing order of last active time
|
|
|
_webio_expire_lock = threading.Lock()
|
|
|
|
|
@@ -143,23 +184,13 @@ class HttpHandler:
|
|
|
if session:
|
|
|
session.close(nonblock=True)
|
|
|
del cls._webio_sessions[sid]
|
|
|
+ del cls._webio_transports[sid]
|
|
|
|
|
|
@classmethod
|
|
|
def _remove_webio_session(cls, sid):
|
|
|
cls._webio_sessions.pop(sid, None)
|
|
|
cls._webio_expire.pop(sid, None)
|
|
|
|
|
|
- @classmethod
|
|
|
- def get_response(cls, sid, ack=0):
|
|
|
- commands, seq = cls._webio_last_commands.get(sid, ([], 0))
|
|
|
- if ack == seq:
|
|
|
- webio_session = cls._webio_sessions[sid]
|
|
|
- commands = webio_session.get_task_commands()
|
|
|
- seq += 1
|
|
|
- cls._webio_last_commands[sid] = (commands, seq)
|
|
|
-
|
|
|
- return {'commands': commands, 'seq': seq}
|
|
|
-
|
|
|
def _process_cors(self, context: HttpContext):
|
|
|
"""Handling cross-domain requests: check the source of the request and set headers"""
|
|
|
origin = context.request_headers().get('Origin', '')
|
|
@@ -209,6 +240,14 @@ class HttpHandler:
|
|
|
return False
|
|
|
return self.cdn
|
|
|
|
|
|
+ def read_event_data(self, context: HttpContext) -> List[Dict]:
|
|
|
+ try:
|
|
|
+ if context.request_headers().get('content-type') == 'application/octet-stream':
|
|
|
+ return [deserialize_binary_event(context.request_body())]
|
|
|
+ return json.loads(context.request_body())
|
|
|
+ except Exception:
|
|
|
+ return []
|
|
|
+
|
|
|
@contextmanager
|
|
|
def handle_request_context(self, context: HttpContext):
|
|
|
"""called when every http request"""
|
|
@@ -240,16 +279,18 @@ class HttpHandler:
|
|
|
context.set_content(html)
|
|
|
return context.get_response()
|
|
|
|
|
|
- webio_session_id = None
|
|
|
+ ack = int(context.request_url_parameter('ack', 0))
|
|
|
+ webio_session_id = request_headers['webio-session-id']
|
|
|
+ new_request = False
|
|
|
+ if webio_session_id.startswith('NEW-'):
|
|
|
+ new_request = True
|
|
|
+ webio_session_id = webio_session_id[4:]
|
|
|
|
|
|
- # 初始请求,创建新 Session
|
|
|
- if not request_headers['webio-session-id'] or request_headers['webio-session-id'] == 'NEW':
|
|
|
+ if new_request and webio_session_id not in cls._webio_sessions: # 初始请求,创建新 Session
|
|
|
if context.request_method() == 'POST': # 不能在POST请求中创建Session,防止CSRF攻击
|
|
|
context.set_status(403)
|
|
|
return context.get_response()
|
|
|
|
|
|
- webio_session_id = random_str(24)
|
|
|
- context.set_header('webio-session-id', webio_session_id)
|
|
|
session_info = get_session_info_from_headers(context.request_headers())
|
|
|
session_info['user_ip'] = context.get_client_ip()
|
|
|
session_info['request'] = context.request_obj()
|
|
@@ -264,17 +305,23 @@ class HttpHandler:
|
|
|
session_cls = ThreadBasedSession
|
|
|
webio_session = session_cls(application, session_info=session_info)
|
|
|
cls._webio_sessions[webio_session_id] = webio_session
|
|
|
- yield type(self).WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
|
|
|
- elif request_headers['webio-session-id'] not in cls._webio_sessions: # WebIOSession deleted
|
|
|
- context.set_content([dict(command='close_session')], json_type=True)
|
|
|
+ cls._webio_transports[webio_session_id] = ReliableTransport(webio_session)
|
|
|
+ yield cls.WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
|
|
|
+ elif webio_session_id not in cls._webio_sessions: # WebIOSession deleted
|
|
|
+ close_msg = ReliableTransport.close_message(ack)
|
|
|
+ context.set_content(close_msg, json_type=True)
|
|
|
return context.get_response()
|
|
|
else:
|
|
|
- webio_session_id = request_headers['webio-session-id']
|
|
|
+ # in this case, the request_headers['webio-session-id'] may also startswith NEW,
|
|
|
+ # this is because the response for the previous new session request has not been received by the client,
|
|
|
+ # and the client has sent a new request with the same session id.
|
|
|
webio_session = cls._webio_sessions[webio_session_id]
|
|
|
|
|
|
if context.request_method() == 'POST': # client push event
|
|
|
- if context.request_json() is not None:
|
|
|
- webio_session.send_client_event(context.request_json())
|
|
|
+ seq = int(context.request_url_parameter('seq', 0))
|
|
|
+ event_data = self.read_event_data(context)
|
|
|
+ submit_cnt = cls._webio_transports[webio_session_id].push_event(event_data, seq)
|
|
|
+ if submit_cnt > 0:
|
|
|
yield type(self).WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
|
|
|
elif context.request_method() == 'GET': # client pull messages
|
|
|
pass
|
|
@@ -283,8 +330,8 @@ class HttpHandler:
|
|
|
|
|
|
self.interval_cleaning()
|
|
|
|
|
|
- ack = int(context.request_url_parameter('ack', 0))
|
|
|
- context.set_content(type(self).get_response(webio_session_id, ack=ack), json_type=True)
|
|
|
+ resp = cls._webio_transports[webio_session_id].get_response(ack)
|
|
|
+ context.set_content(resp, json_type=True)
|
|
|
|
|
|
if webio_session.closed():
|
|
|
self._remove_webio_session(webio_session_id)
|