فهرست منبع

Merge pull request #532 from pywebio/reliable-http-session

Reliable http session
WangWeimin 2 سال پیش
والد
کامیت
732ebc5832
4فایلهای تغییر یافته به همراه248 افزوده شده و 89 حذف شده
  1. 2 2
      pywebio/__version__.py
  2. 88 41
      pywebio/platform/adaptor/http.py
  3. 80 46
      webiojs/src/session.ts
  4. 78 0
      webiojs/src/utils.ts

+ 2 - 2
pywebio/__version__.py

@@ -1,8 +1,8 @@
 __package__ = 'pywebio'
 __description__ = 'Write interactive web app in script way.'
 __url__ = 'https://pywebio.readthedocs.io'
-__version__ = "1.7.0"
-__version_info__ = (1, 7, 0, 0)
+__version__ = "1.7.1"
+__version_info__ = (1, 7, 1, 0)
 __author__ = 'WangWeimin'
 __author_email__ = 'wang0.618@qq.com'
 __license__ = 'MIT'

+ 88 - 41
pywebio/platform/adaptor/http.py

@@ -14,12 +14,13 @@ import logging
 import threading
 import time
 from contextlib import contextmanager
-from typing import Dict, Optional
+from typing import Dict, Optional, List
+from collections import deque
 
 from ..page import make_applications, render_page
 from ..utils import deserialize_binary_event
 from ...session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target
-from ...session.base import get_session_info_from_headers
+from ...session.base import get_session_info_from_headers, Session
 from ...utils import random_str, LRUDict, isgeneratorfunction, iscoroutinefunction, check_webio_js
 
 
@@ -35,7 +36,7 @@ class HttpContext:
         Return the current request object"""
         pass
 
-    def request_method(self):
+    def request_method(self) -> str:
         """返回当前请求的方法,大写
         Return the HTTP method of the current request, uppercase"""
         pass
@@ -45,12 +46,12 @@ class HttpContext:
         Return the header dictionary of the current request"""
         pass
 
-    def request_url_parameter(self, name, default=None):
+    def request_url_parameter(self, name, default=None) -> str:
         """返回当前请求的URL参数
         Returns the value of the given URL parameter of the current request"""
         pass
 
-    def request_body(self):
+    def request_body(self) -> bytes:
         """返回当前请求的body数据
         Returns the data of the current request body
 
@@ -58,16 +59,6 @@ class HttpContext:
         """
         return b''
 
-    def request_json(self) -> Optional[Dict]:
-        """返回当前请求的json反序列化后的内容,若请求数据不为json格式,返回None
-        Return the data (json deserialization) of the currently requested, if the data is not in json format, return None"""
-        try:
-            if self.request_headers().get('content-type') == 'application/octet-stream':
-                return deserialize_binary_event(self.request_body())
-            return json.loads(self.request_body())
-        except Exception:
-            return None
-
     def set_header(self, name, value):
         """为当前响应设置header
         Set a header for the current response"""
@@ -92,7 +83,7 @@ class HttpContext:
         Get the current response object"""
         pass
 
-    def get_client_ip(self):
+    def get_client_ip(self) -> str:
         """获取用户的ip
         Get the user's ip"""
         pass
@@ -102,6 +93,56 @@ logger = logging.getLogger(__name__)
 _event_loop = None
 
 
+class ReliableTransport:
+    def __init__(self, session: Session, message_window: int = 4):
+        self.session = session
+        self.messages = deque()
+        self.window_size = message_window
+        self.min_msg_id = 0  # the id of the first message in the window
+        self.finished_event_id = -1  # the id of the last finished event
+
+    @staticmethod
+    def close_message(ack):
+        return dict(
+            commands=[[dict(command='close_session')]],
+            seq=ack + 1
+        )
+
+    def push_event(self, events: List[Dict], seq: int) -> int:
+        """Send client events to the session and return the success message count"""
+        if not events:
+            return 0
+
+        submit_cnt = 0
+        for eid, event in enumerate(events, start=seq):
+            if eid > self.finished_event_id:
+                self.finished_event_id = eid  # todo: use lock for check and set operation
+                self.session.send_client_event(event)
+                submit_cnt += 1
+
+        return submit_cnt
+
+    def get_response(self, ack=0):
+        """
+        ack num is the number of messages that the client has received.
+        response is a list of messages that the client should receive, along with their min id `seq`.
+        """
+        while ack >= self.min_msg_id and self.messages:
+            self.messages.popleft()
+            self.min_msg_id += 1
+
+        if len(self.messages) < self.window_size:
+            msgs = self.session.get_task_commands()
+            if msgs:
+                self.messages.append(msgs)
+
+        return dict(
+            commands=list(self.messages),
+            seq=self.min_msg_id,
+            ack=self.finished_event_id
+        )
+
+
 # todo: use lock to avoid thread race condition
 class HttpHandler:
     """基于HTTP的后端Handler实现
@@ -112,7 +153,7 @@ class HttpHandler:
 
     """
     _webio_sessions = {}  # WebIOSessionID -> WebIOSession()
-    _webio_last_commands = {}  # WebIOSessionID -> (last commands, commands sequence id)
+    _webio_transports = {}  # WebIOSessionID -> ReliableTransport(), type: Dict[str, ReliableTransport]
     _webio_expire = LRUDict()  # WebIOSessionID -> last active timestamp. In increasing order of last active time
     _webio_expire_lock = threading.Lock()
 
@@ -143,23 +184,13 @@ class HttpHandler:
             if session:
                 session.close(nonblock=True)
                 del cls._webio_sessions[sid]
+                del cls._webio_transports[sid]
 
     @classmethod
     def _remove_webio_session(cls, sid):
         cls._webio_sessions.pop(sid, None)
         cls._webio_expire.pop(sid, None)
 
-    @classmethod
-    def get_response(cls, sid, ack=0):
-        commands, seq = cls._webio_last_commands.get(sid, ([], 0))
-        if ack == seq:
-            webio_session = cls._webio_sessions[sid]
-            commands = webio_session.get_task_commands()
-            seq += 1
-            cls._webio_last_commands[sid] = (commands, seq)
-
-        return {'commands': commands, 'seq': seq}
-
     def _process_cors(self, context: HttpContext):
         """Handling cross-domain requests: check the source of the request and set headers"""
         origin = context.request_headers().get('Origin', '')
@@ -209,6 +240,14 @@ class HttpHandler:
             return False
         return self.cdn
 
+    def read_event_data(self, context: HttpContext) -> List[Dict]:
+        try:
+            if context.request_headers().get('content-type') == 'application/octet-stream':
+                return [deserialize_binary_event(context.request_body())]
+            return json.loads(context.request_body())
+        except Exception:
+            return []
+
     @contextmanager
     def handle_request_context(self, context: HttpContext):
         """called when every http request"""
@@ -240,16 +279,18 @@ class HttpHandler:
             context.set_content(html)
             return context.get_response()
 
-        webio_session_id = None
+        ack = int(context.request_url_parameter('ack', 0))
+        webio_session_id = request_headers['webio-session-id']
+        new_request = False
+        if webio_session_id.startswith('NEW-'):
+            new_request = True
+            webio_session_id = webio_session_id[4:]
 
-        # 初始请求,创建新 Session
-        if not request_headers['webio-session-id'] or request_headers['webio-session-id'] == 'NEW':
+        if new_request and webio_session_id not in cls._webio_sessions:  # 初始请求,创建新 Session
             if context.request_method() == 'POST':  # 不能在POST请求中创建Session,防止CSRF攻击
                 context.set_status(403)
                 return context.get_response()
 
-            webio_session_id = random_str(24)
-            context.set_header('webio-session-id', webio_session_id)
             session_info = get_session_info_from_headers(context.request_headers())
             session_info['user_ip'] = context.get_client_ip()
             session_info['request'] = context.request_obj()
@@ -264,17 +305,23 @@ class HttpHandler:
                 session_cls = ThreadBasedSession
             webio_session = session_cls(application, session_info=session_info)
             cls._webio_sessions[webio_session_id] = webio_session
-            yield type(self).WAIT_MS_ON_POST / 1000.0  # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
-        elif request_headers['webio-session-id'] not in cls._webio_sessions:  # WebIOSession deleted
-            context.set_content([dict(command='close_session')], json_type=True)
+            cls._webio_transports[webio_session_id] = ReliableTransport(webio_session)
+            yield cls.WAIT_MS_ON_POST / 1000.0  # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
+        elif webio_session_id not in cls._webio_sessions:  # WebIOSession deleted
+            close_msg = ReliableTransport.close_message(ack)
+            context.set_content(close_msg, json_type=True)
             return context.get_response()
         else:
-            webio_session_id = request_headers['webio-session-id']
+            # in this case, the request_headers['webio-session-id'] may also startswith NEW,
+            # this is because the response for the previous new session request has not been received by the client,
+            # and the client has sent a new request with the same session id.
             webio_session = cls._webio_sessions[webio_session_id]
 
         if context.request_method() == 'POST':  # client push event
-            if context.request_json() is not None:
-                webio_session.send_client_event(context.request_json())
+            seq = int(context.request_url_parameter('seq', 0))
+            event_data = self.read_event_data(context)
+            submit_cnt = cls._webio_transports[webio_session_id].push_event(event_data, seq)
+            if submit_cnt > 0:
                 yield type(self).WAIT_MS_ON_POST / 1000.0  # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
         elif context.request_method() == 'GET':  # client pull messages
             pass
@@ -283,8 +330,8 @@ class HttpHandler:
 
         self.interval_cleaning()
 
-        ack = int(context.request_url_parameter('ack', 0))
-        context.set_content(type(self).get_response(webio_session_id, ack=ack), json_type=True)
+        resp = cls._webio_transports[webio_session_id].get_response(ack)
+        context.set_content(resp, json_type=True)
 
         if webio_session.closed():
             self._remove_webio_session(webio_session_id)

+ 80 - 46
webiojs/src/session.ts

@@ -1,4 +1,4 @@
-import {error_alert} from "./utils";
+import {error_alert, randomid, ReliableSender} from "./utils";
 import {state} from "./state";
 import {t} from "./i18n";
 
@@ -178,10 +178,11 @@ export class WebSocketSession implements Session {
 
 export class HttpSession implements Session {
     interval_pull_id: number = null;
-    webio_session_id: string = 'NEW';
+    webio_session_id: string = '';
     debug = false;
 
-    private _executed_command_msg_id = 0;
+    private sender: ReliableSender = null;
+    private _executed_command_msg_id = -1;
     private _closed = false;
     private _session_create_callbacks: (() => void)[] = [];
     private _session_close_callbacks: (() => void)[] = [];
@@ -193,6 +194,7 @@ export class HttpSession implements Session {
         let url = new URL(api_url, window.location.href);
         url.search = "?app=" + app_name;
         this.api_url = url.href;
+        this.sender = new ReliableSender(this._send.bind(this));
     }
 
     on_session_create(callback: () => void): void {
@@ -209,6 +211,7 @@ export class HttpSession implements Session {
 
     start_session(debug: boolean = false): void {
         this.debug = debug;
+        this.webio_session_id = "NEW-" + randomid(24);
         this.pull();
         this.interval_pull_id = setInterval(() => {
             this.pull()
@@ -223,74 +226,104 @@ export class HttpSession implements Session {
             contentType: "application/json; charset=utf-8",
             dataType: "json",
             headers: {"webio-session-id": this.webio_session_id},
-            success: function (data: { commands: Command[], seq: number }, textStatus: string, jqXHR: JQuery.jqXHR) {
+            success: function (data: { commands: Command[][], seq: number, event: number, ack: number },
+                               textStatus: string, jqXHR: JQuery.jqXHR) {
                 safe_poprun_callbacks(that._session_create_callbacks, 'session_create_callback');
                 that._on_request_success(data, textStatus, jqXHR);
-            },
-            error: function () {
-                console.error('Http pulling failed');
+                if (that.webio_session_id.startsWith("NEW-")) {
+                    that.webio_session_id = that.webio_session_id.substring(4);
+                }
             }
         })
     }
 
-    private _on_request_success(data: { commands: Command[], seq: number }, textStatus: string, jqXHR: JQuery.jqXHR) {
-        if (data.seq == this._executed_command_msg_id)
+    private _on_request_success(data: { commands: Command[][], seq: number, ack: number },
+                                textStatus: string, jqXHR: JQuery.jqXHR) {
+        this.sender.ack(data.ack);
+
+        let msg_start_idx = this._executed_command_msg_id - data.seq + 1;
+        if (data.commands.length <= msg_start_idx)
             return;
-        this._executed_command_msg_id = data.seq;
+        this._executed_command_msg_id = data.seq + data.commands.length - 1;
 
         let sid = jqXHR.getResponseHeader('webio-session-id');
         if (sid)
             this.webio_session_id = sid;
 
-        for (let msg of data.commands) {
-            if (this.debug) console.info('>>>', msg);
-            this._on_server_message(msg);
+        for (let msgs of data.commands.slice(msg_start_idx)) {
+            for (let msg of msgs) {
+                if (this.debug) console.info('>>>', msg);
+                this._on_server_message(msg);
+            }
         }
     };
 
     send_message(msg: ClientEvent, onprogress?: (loaded: number, total: number) => void): void {
         if (this.debug) console.info('<<<', msg);
-        this._send({
-            data: JSON.stringify(msg),
-            contentType: "application/json; charset=utf-8",
-        }, onprogress);
+        this.sender.add_send_task({
+            data: msg,
+            json: true,
+            onprogress: onprogress,
+        })
     }
 
     send_buffer(data: Blob, onprogress?: (loaded: number, total: number) => void): void {
         if (this.debug) console.info('<<< Blob data...');
-        this._send({
+        this.sender.add_send_task({
             data: data,
-            cache: false,
-            processData: false,
-            contentType: 'application/octet-stream',
-        }, onprogress);
+            json: false,
+            onprogress: onprogress,
+        }, false)
     }
 
-    _send(options: { [key: string]: any; }, onprogress?: (loaded: number, total: number) => void): void {
-        if (this.closed())
-            return error_alert(t("disconnected_with_server"));
-
-        $.ajax({
-            ...options,
-            type: "POST",
-            url: `${this.api_url}&ack=${this._executed_command_msg_id}`,
-            dataType: "json",
-            headers: {"webio-session-id": this.webio_session_id},
-            success: this._on_request_success.bind(this),
-            xhr: function () {
-                let xhr = new window.XMLHttpRequest();
-                // Upload progress
-                xhr.upload.addEventListener("progress", function (evt) {
-                    if (evt.lengthComputable && onprogress) {
-                        onprogress(evt.loaded, evt.total);
-                    }
-                }, false);
-                return xhr;
-            },
-            error: function () {
-                console.error('Http push blob data failed');
-                error_alert(t("connect_fail"));
+    _send(params: { [key: string]: any; }[], seq: number): Promise<void> {
+        if (this.closed()) {
+            this.sender.stop();
+            error_alert(t("disconnected_with_server"));
+            return Promise.reject();
+        }
+        let data: any, ajax_options: any;
+        let json = params.some(p => p.json);
+        if (json) {
+            data = JSON.stringify(params.map(p => p.data));
+            ajax_options = {
+                contentType: "application/json; charset=utf-8",
+            }
+        } else {
+            data = params[0].data;
+            ajax_options = {
+                cache: false,
+                processData: false,
+                contentType: 'application/octet-stream',
             }
+        }
+        return new Promise((resolve, reject) => {
+            $.ajax({
+                data: data,
+                ...ajax_options,
+                type: "POST",
+                url: `${this.api_url}&ack=${this._executed_command_msg_id}&seq=${seq}`,
+                dataType: "json",
+                headers: {"webio-session-id": this.webio_session_id},
+                success: this._on_request_success.bind(this),
+                xhr: function () {
+                    let xhr = new window.XMLHttpRequest();
+                    // Upload progress
+                    xhr.upload.addEventListener("progress", function (evt) {
+                        if (evt.lengthComputable) {
+                            params.forEach(p => {
+                                if (p.onprogress) // only the first one
+                                    p.onprogress(evt.loaded, evt.total);
+                                p.onprogress = null;
+                            });
+                        }
+                    }, false);
+                    return xhr;
+                },
+                error: function () {
+                    console.error('Http push event failed, will retry');
+                }
+            }).always(() => resolve());
         });
     }
 
@@ -298,6 +331,7 @@ export class HttpSession implements Session {
         this._closed = true;
         safe_poprun_callbacks(this._session_close_callbacks, 'session_close_callback');
         clearInterval(this.interval_pull_id);
+        this.sender.stop();
     }
 
     closed(): boolean {

+ 78 - 0
webiojs/src/utils.ts

@@ -183,4 +183,82 @@ export function is_mobile() {
     if (navigator.userAgentData) return navigator.userAgentData.mobile;
     const ipadOS = (navigator.platform === 'MacIntel' && navigator.maxTouchPoints > 1); /* iPad OS 13 */
     return /android|webos|iphone|ipad|ipod|blackberry|iemobile|opera mini/i.test(navigator.userAgent.toLowerCase()) || ipadOS;
+}
+
+// put send task to a queue and run it one by one
+export class ReliableSender {
+    private seq = 0;
+    private queue: { enable_batch: boolean, param: any }[] = [];
+    private send_running = false
+    private _stop = false;
+
+    constructor(
+        private readonly sender: (params: any[], seq: number) => Promise<void>,
+        private window_size: number = 8,
+        init_seq = 0, private timeout = 2000
+    ) {
+        this.sender = sender;
+        this.window_size = window_size;
+        this.timeout = timeout;
+        this.seq = init_seq;
+        this.queue = [];
+    }
+
+    /*
+    * for continuous batch_send tasks in queue, they will be sent in one sender, the sending will retry when it finished or timeout.
+    * for non-batch task, each will be sent in a single sender, the sending will retry when it finished.
+    * */
+    add_send_task(param: any, allow_batch_send = true) {
+        if (this._stop) return;
+        this.queue.push({
+            enable_batch: allow_batch_send,
+            param: param
+        });
+        if (!this.send_running)
+            this.start_send();
+    }
+
+    private start_send() {
+        if (this._stop || this.queue.length === 0) {
+            this.send_running = false;
+            return;
+        }
+        this.send_running = true;
+        let params: any[] = [];
+        for (let item of this.queue) {
+            if (!item.enable_batch)
+                break;
+            params.push(item.param);
+        }
+        let batch_send = true;
+        if (params.length === 0 && !this.queue[0].enable_batch) {
+            batch_send = false;
+            params.push(this.queue[0].param);
+        }
+        if (params.length === 0) {
+            this.send_running = false;
+            return;
+        }
+
+        let promises = [this.sender(params, this.seq)];
+        if (batch_send)
+            promises.push(new Promise((resolve) => setTimeout(resolve, this.timeout)));
+
+        Promise.race(promises).then(() => {
+            this.start_send();
+        });
+    }
+
+    // seq for each ack call must be larger than the previous one, otherwise the ack will be ignored
+    ack(seq: number) {
+        if (seq < this.seq)
+            return;
+        let pop_count = seq - this.seq + 1;
+        this.queue = this.queue.slice(pop_count);
+        this.seq = seq + 1;
+    }
+
+    stop() {
+        this._stop = true;
+    }
 }