|
@@ -4,27 +4,25 @@ import json
|
|
import logging
|
|
import logging
|
|
import os
|
|
import os
|
|
import threading
|
|
import threading
|
|
-import time
|
|
|
|
|
|
+import typing
|
|
import webbrowser
|
|
import webbrowser
|
|
from functools import partial
|
|
from functools import partial
|
|
-from typing import Dict
|
|
|
|
from urllib.parse import urlparse
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
import tornado
|
|
import tornado
|
|
import tornado.httpserver
|
|
import tornado.httpserver
|
|
import tornado.ioloop
|
|
import tornado.ioloop
|
|
-from tornado.web import StaticFileHandler
|
|
|
|
-from tornado.websocket import WebSocketHandler
|
|
|
|
|
|
+import tornado.web
|
|
|
|
+import tornado.websocket
|
|
|
|
|
|
from . import page
|
|
from . import page
|
|
-from .remote_access import start_remote_access_service
|
|
|
|
|
|
+from .adaptor import ws as ws_adaptor
|
|
from .page import make_applications, render_page
|
|
from .page import make_applications, render_page
|
|
-from .utils import cdn_validation, deserialize_binary_event, print_listen_address
|
|
|
|
-from ..session import CoroutineBasedSession, ThreadBasedSession, ScriptModeSession, \
|
|
|
|
- register_session_implement_for_target, Session
|
|
|
|
|
|
+from .remote_access import start_remote_access_service
|
|
|
|
+from .utils import cdn_validation, print_listen_address
|
|
|
|
+from ..session import ScriptModeSession, register_session_implement_for_target, Session
|
|
from ..session.base import get_session_info_from_headers
|
|
from ..session.base import get_session_info_from_headers
|
|
-from ..utils import get_free_port, wait_host_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction, \
|
|
|
|
- check_webio_js, parse_file_size, random_str, LRUDict
|
|
|
|
|
|
+from ..utils import get_free_port, wait_host_port, STATIC_PATH, check_webio_js, parse_file_size
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@@ -45,7 +43,7 @@ def ioloop() -> tornado.ioloop.IOLoop:
|
|
return _ioloop
|
|
return _ioloop
|
|
|
|
|
|
|
|
|
|
-def _check_origin(origin, allowed_origins, handler: WebSocketHandler):
|
|
|
|
|
|
+def _check_origin(origin, allowed_origins, handler: tornado.websocket.WebSocketHandler):
|
|
if _is_same_site(origin, handler):
|
|
if _is_same_site(origin, handler):
|
|
return True
|
|
return True
|
|
|
|
|
|
@@ -55,7 +53,7 @@ def _check_origin(origin, allowed_origins, handler: WebSocketHandler):
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
-def _is_same_site(origin, handler: WebSocketHandler):
|
|
|
|
|
|
+def _is_same_site(origin, handler: tornado.websocket.WebSocketHandler):
|
|
parsed_origin = urlparse(origin)
|
|
parsed_origin = urlparse(origin)
|
|
origin = parsed_origin.netloc
|
|
origin = parsed_origin.netloc
|
|
origin = origin.lower()
|
|
origin = origin.lower()
|
|
@@ -66,6 +64,32 @@ def _is_same_site(origin, handler: WebSocketHandler):
|
|
return origin == host
|
|
return origin == host
|
|
|
|
|
|
|
|
|
|
|
|
+class WebSocketConnection(ws_adaptor.WebSocketConnection):
|
|
|
|
+
|
|
|
|
+ def __init__(self, context: tornado.websocket.WebSocketHandler):
|
|
|
|
+ self.context = context
|
|
|
|
+
|
|
|
|
+ def get_query_argument(self, name) -> typing.Optional[str]:
|
|
|
|
+ return self.context.get_query_argument(name, None)
|
|
|
|
+
|
|
|
|
+ def make_session_info(self) -> dict:
|
|
|
|
+ session_info = get_session_info_from_headers(self.context.request.headers)
|
|
|
|
+ session_info['user_ip'] = self.context.request.remote_ip
|
|
|
|
+ session_info['request'] = self.context.request
|
|
|
|
+ session_info['backend'] = 'tornado'
|
|
|
|
+ session_info['protocol'] = 'websocket'
|
|
|
|
+ return session_info
|
|
|
|
+
|
|
|
|
+ def write_message(self, message: dict):
|
|
|
|
+ self.context.write_message(json.dumps(message))
|
|
|
|
+
|
|
|
|
+ def closed(self) -> bool:
|
|
|
|
+ return not bool(self.context.ws_connection)
|
|
|
|
+
|
|
|
|
+ def close(self):
|
|
|
|
+ self.context.close()
|
|
|
|
+
|
|
|
|
+
|
|
def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origin_func=_is_same_site): # noqa: C901
|
|
def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origin_func=_is_same_site): # noqa: C901
|
|
"""
|
|
"""
|
|
:param dict applications: dict of `name -> task function`
|
|
:param dict applications: dict of `name -> task function`
|
|
@@ -78,16 +102,10 @@ def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origi
|
|
if applications is None:
|
|
if applications is None:
|
|
applications = dict(index=lambda: None) # mock PyWebIO app
|
|
applications = dict(index=lambda: None) # mock PyWebIO app
|
|
|
|
|
|
- class WSHandler(WebSocketHandler):
|
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
|
- super().__init__(*args, **kwargs)
|
|
|
|
- self._close_from_session = False
|
|
|
|
- self.session_id = None
|
|
|
|
- self.session = None # type: Session
|
|
|
|
- if reconnect_timeout and not type(self)._started_clean_task:
|
|
|
|
- type(self)._started_clean_task = True
|
|
|
|
- tornado.ioloop.IOLoop.current().call_later(reconnect_timeout // 2, type(self).clean_expired_sessions)
|
|
|
|
- logger.debug("Started session clean task")
|
|
|
|
|
|
+ ws_adaptor.set_expire_second(reconnect_timeout)
|
|
|
|
+ tornado.ioloop.IOLoop.current().spawn_callback(ws_adaptor.session_clean_task)
|
|
|
|
+
|
|
|
|
+ class Handler(tornado.websocket.WebSocketHandler):
|
|
|
|
|
|
def get_app(self):
|
|
def get_app(self):
|
|
app_name = self.get_query_argument('app', 'index')
|
|
app_name = self.get_query_argument('app', 'index')
|
|
@@ -100,7 +118,7 @@ def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origi
|
|
return cdn
|
|
return cdn
|
|
|
|
|
|
async def get(self, *args, **kwargs) -> None:
|
|
async def get(self, *args, **kwargs) -> None:
|
|
- # It's a simple http GET request
|
|
|
|
|
|
+ """http GET request"""
|
|
if self.request.headers.get("Upgrade", "").lower() != "websocket":
|
|
if self.request.headers.get("Upgrade", "").lower() != "websocket":
|
|
# Backward compatible
|
|
# Backward compatible
|
|
# Frontend detect whether the backend is http server
|
|
# Frontend detect whether the backend is http server
|
|
@@ -108,7 +126,6 @@ def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origi
|
|
return self.write('')
|
|
return self.write('')
|
|
|
|
|
|
app = self.get_app()
|
|
app = self.get_app()
|
|
-
|
|
|
|
html = render_page(app, protocol='ws', cdn=self.get_cdn())
|
|
html = render_page(app, protocol='ws', cdn=self.get_cdn())
|
|
return self.write(html)
|
|
return self.write(html)
|
|
else:
|
|
else:
|
|
@@ -121,121 +138,21 @@ def _webio_handler(applications=None, cdn=True, reconnect_timeout=0, check_origi
|
|
# Non-None enables compression with default options.
|
|
# Non-None enables compression with default options.
|
|
return {}
|
|
return {}
|
|
|
|
|
|
- @classmethod
|
|
|
|
- def clean_expired_sessions(cls):
|
|
|
|
- tornado.ioloop.IOLoop.current().call_later(reconnect_timeout // 2, cls.clean_expired_sessions)
|
|
|
|
-
|
|
|
|
- while cls._session_expire:
|
|
|
|
- session_id, expire_ts = cls._session_expire.popitem(last=False) # 弹出最早过期的session
|
|
|
|
-
|
|
|
|
- if time.time() < expire_ts:
|
|
|
|
- # this session is not expired
|
|
|
|
- cls._session_expire[session_id] = expire_ts # restore this item
|
|
|
|
- cls._session_expire.move_to_end(session_id, last=False) # move to front
|
|
|
|
- break
|
|
|
|
-
|
|
|
|
- # clean this session
|
|
|
|
- logger.debug("session %s expired" % session_id)
|
|
|
|
- cls._connections.pop(session_id, None)
|
|
|
|
- session = cls._webio_sessions.pop(session_id, None)
|
|
|
|
- if session:
|
|
|
|
- session.close(nonblock=True)
|
|
|
|
-
|
|
|
|
- @classmethod
|
|
|
|
- def send_msg_to_client(cls, _, session_id=None):
|
|
|
|
- conn = cls._connections.get(session_id)
|
|
|
|
- session = cls._webio_sessions[session_id]
|
|
|
|
-
|
|
|
|
- if not conn or not conn.ws_connection:
|
|
|
|
- return
|
|
|
|
-
|
|
|
|
- for msg in session.get_task_commands():
|
|
|
|
- try:
|
|
|
|
- conn.write_message(json.dumps(msg))
|
|
|
|
- except TypeError as e:
|
|
|
|
- logger.exception('Data serialization error: %s\n'
|
|
|
|
- 'This may be because you pass the wrong type of parameter to the function'
|
|
|
|
- ' of PyWebIO.\nData content: %s', e, msg)
|
|
|
|
-
|
|
|
|
- @classmethod
|
|
|
|
- def close_from_session(cls, session_id=None):
|
|
|
|
- cls.send_msg_to_client(None, session_id=session_id)
|
|
|
|
-
|
|
|
|
- conn = cls._connections.pop(session_id, None)
|
|
|
|
- cls._webio_sessions.pop(session_id, None)
|
|
|
|
- if conn and conn.ws_connection:
|
|
|
|
- conn._close_from_session = True
|
|
|
|
- conn.close()
|
|
|
|
-
|
|
|
|
- _started_clean_task = False
|
|
|
|
- _session_expire = LRUDict() # session_id -> expire timestamp. In increasing order of expire time
|
|
|
|
- _webio_sessions = {} # type: Dict[str, Session] # session_id -> session
|
|
|
|
- _connections = {} # type: Dict[str, WSHandler] # session_id -> WSHandler
|
|
|
|
|
|
+ _handler: ws_adaptor.WebSocketHandler
|
|
|
|
|
|
def open(self):
|
|
def open(self):
|
|
- logger.debug("WebSocket opened")
|
|
|
|
- cls = type(self)
|
|
|
|
-
|
|
|
|
- self.session_id = self.get_query_argument('session', None)
|
|
|
|
- if self.session_id in ('NEW', None): # 初始请求,创建新 Session
|
|
|
|
- session_info = get_session_info_from_headers(self.request.headers)
|
|
|
|
- session_info['user_ip'] = self.request.remote_ip
|
|
|
|
- session_info['request'] = self.request
|
|
|
|
- session_info['backend'] = 'tornado'
|
|
|
|
- session_info['protocol'] = 'websocket'
|
|
|
|
-
|
|
|
|
- application = self.get_app()
|
|
|
|
- self.session_id = random_str(24)
|
|
|
|
- cls._connections[self.session_id] = self
|
|
|
|
-
|
|
|
|
- if iscoroutinefunction(application) or isgeneratorfunction(application):
|
|
|
|
- self.session = CoroutineBasedSession(
|
|
|
|
- application, session_info=session_info,
|
|
|
|
- on_task_command=partial(self.send_msg_to_client, session_id=self.session_id),
|
|
|
|
- on_session_close=partial(self.close_from_session, session_id=self.session_id))
|
|
|
|
- else:
|
|
|
|
- self.session = ThreadBasedSession(
|
|
|
|
- application, session_info=session_info,
|
|
|
|
- on_task_command=partial(self.send_msg_to_client, session_id=self.session_id),
|
|
|
|
- on_session_close=partial(self.close_from_session, session_id=self.session_id),
|
|
|
|
- loop=asyncio.get_event_loop())
|
|
|
|
- cls._webio_sessions[self.session_id] = self.session
|
|
|
|
-
|
|
|
|
- if reconnect_timeout:
|
|
|
|
- self.write_message(json.dumps(dict(command='set_session_id', spec=self.session_id)))
|
|
|
|
-
|
|
|
|
- elif self.session_id not in cls._webio_sessions: # WebIOSession deleted
|
|
|
|
- self.write_message(json.dumps(dict(command='close_session')))
|
|
|
|
- else:
|
|
|
|
- self.session = cls._webio_sessions[self.session_id]
|
|
|
|
- cls._session_expire.pop(self.session_id, None)
|
|
|
|
- cls._connections[self.session_id] = self
|
|
|
|
- cls.send_msg_to_client(self.session, self.session_id)
|
|
|
|
-
|
|
|
|
- logger.debug('session id: %s' % self.session_id)
|
|
|
|
|
|
+ conn = WebSocketConnection(self)
|
|
|
|
+ self._handler = ws_adaptor.WebSocketHandler(
|
|
|
|
+ connection=conn, application=self.get_app(), reconnectable=bool(reconnect_timeout)
|
|
|
|
+ )
|
|
|
|
|
|
def on_message(self, message):
|
|
def on_message(self, message):
|
|
- if isinstance(message, bytes):
|
|
|
|
- event = deserialize_binary_event(message)
|
|
|
|
- else:
|
|
|
|
- event = json.loads(message)
|
|
|
|
- if event is None:
|
|
|
|
- return
|
|
|
|
- self.session.send_client_event(event)
|
|
|
|
|
|
+ self._handler.send_client_data(message)
|
|
|
|
|
|
def on_close(self):
|
|
def on_close(self):
|
|
- cls = type(self)
|
|
|
|
- cls._connections.pop(self.session_id, None)
|
|
|
|
- if not reconnect_timeout and not self._close_from_session:
|
|
|
|
- self.session.close(nonblock=True)
|
|
|
|
- elif reconnect_timeout:
|
|
|
|
- if self._close_from_session:
|
|
|
|
- cls._webio_sessions.pop(self.session_id, None)
|
|
|
|
- elif self.session:
|
|
|
|
- cls._session_expire[self.session_id] = time.time() + reconnect_timeout
|
|
|
|
- logger.debug("WebSocket closed")
|
|
|
|
|
|
+ self._handler.notify_connection_lost()
|
|
|
|
|
|
- return WSHandler
|
|
|
|
|
|
+ return Handler
|
|
|
|
|
|
|
|
|
|
def webio_handler(applications, cdn=True, reconnect_timeout=0, allowed_origins=None, check_origin=None):
|
|
def webio_handler(applications, cdn=True, reconnect_timeout=0, allowed_origins=None, check_origin=None):
|
|
@@ -278,9 +195,9 @@ def _setup_server(webio_handler, port=0, host='', static_dir=None, max_buffer_si
|
|
handlers = [(r"/", webio_handler)]
|
|
handlers = [(r"/", webio_handler)]
|
|
|
|
|
|
if static_dir is not None:
|
|
if static_dir is not None:
|
|
- handlers.append((r"/static/(.*)", StaticFileHandler, {"path": static_dir}))
|
|
|
|
|
|
+ handlers.append((r"/static/(.*)", tornado.web.StaticFileHandler, {"path": static_dir}))
|
|
|
|
|
|
- handlers.append((r"/(.*)", StaticFileHandler, {"path": STATIC_PATH, 'default_filename': 'index.html'}))
|
|
|
|
|
|
+ handlers.append((r"/(.*)", tornado.web.StaticFileHandler, {"path": STATIC_PATH, 'default_filename': 'index.html'}))
|
|
|
|
|
|
app = tornado.web.Application(handlers=handlers, **tornado_app_settings)
|
|
app = tornado.web.Application(handlers=handlers, **tornado_app_settings)
|
|
# Credit: https://stackoverflow.com/questions/19074972/content-length-too-long-when-uploading-file-using-tornado
|
|
# Credit: https://stackoverflow.com/questions/19074972/content-length-too-long-when-uploading-file-using-tornado
|