Explorar o código

auto start server in script mode when not explict invoke `start_server`

wangweimin %!s(int64=5) %!d(string=hai) anos
pai
achega
f728eaa65c

+ 18 - 0
pywebio/exceptions.py

@@ -0,0 +1,18 @@
+"""
+pywebio.exceptions
+~~~~~~~~~~~~~~~~~~~
+
+This module contains the set of PyWebIO's exceptions.
+"""
+
+
+class SessionException(Exception):
+    pass
+
+
+class SessionClosedException(SessionException):
+    pass
+
+
+class SessionNotFoundException(SessionException):
+    pass

+ 1 - 7
pywebio/platform/__init__.py

@@ -1,9 +1,3 @@
-from os.path import abspath, dirname
-
-project_dir = dirname(dirname(abspath(__file__)))
-
-STATIC_PATH = '%s/html' % project_dir
-
-from .tornado import start_server
+from .tornado import start_server, start_server_in_current_thread_session
 
 __all__ = ['start_server']

+ 4 - 3
pywebio/platform/flask.py

@@ -27,8 +27,8 @@ from typing import Dict
 
 from flask import Flask, request, jsonify, send_from_directory
 
-from . import STATIC_PATH
-from ..session import AsyncBasedSession, ThreadBasedWebIOSession, get_session_implement, AbstractSession
+from ..session import AsyncBasedSession, ThreadBasedWebIOSession, get_session_implement, AbstractSession, mark_server_started
+from ..utils import STATIC_PATH
 from ..utils import random_str, LRUDict
 
 # todo: use lock to avoid thread race condition
@@ -131,7 +131,6 @@ def start_flask_server(coro_func, port=8080, host='localhost', disable_asyncio=F
                        session_expire_seconds=DEFAULT_SESSION_EXPIRE_SECONDS,
                        debug=False, **flask_options):
     """
-
     :param coro_func:
     :param port:
     :param host:
@@ -143,6 +142,8 @@ def start_flask_server(coro_func, port=8080, host='localhost', disable_asyncio=F
     :param flask_options:
     :return:
     """
+    mark_server_started()
+
     app = Flask(__name__)
     app.route('/io', methods=['GET', 'POST'])(webio_view(coro_func, session_expire_seconds))
 

+ 79 - 28
pywebio/platform/tornado.py

@@ -1,6 +1,7 @@
 import asyncio
 import json
 import logging
+import threading
 import webbrowser
 
 import tornado
@@ -8,9 +9,9 @@ import tornado.httpserver
 import tornado.ioloop
 import tornado.websocket
 from tornado.web import StaticFileHandler
-from . import STATIC_PATH
-from ..session import AsyncBasedSession, ThreadBasedWebIOSession, get_session_implement
-from ..utils import get_free_port, wait_host_port
+from ..session import AsyncBasedSession, ThreadBasedWebIOSession, get_session_implement, DesignatedThreadSession, \
+    mark_server_started
+from ..utils import get_free_port, wait_host_port, STATIC_PATH
 
 logger = logging.getLogger(__name__)
 
@@ -25,35 +26,35 @@ def webio_handler(task_func):
             # Non-None enables compression with default options.
             return {}
 
-        def send_msg_to_client(self, controller: AsyncBasedSession):
-            for msg in controller.get_task_messages():
+        def send_msg_to_client(self, session: AsyncBasedSession):
+            for msg in session.get_task_messages():
                 self.write_message(json.dumps(msg))
 
         def open(self):
             logger.debug("WebSocket opened")
             self.set_nodelay(True)
 
-            self._close_from_session = False  # 是否从session中关闭连接
+            self._close_from_session_tag = False  # 是否从session中关闭连接
 
             if get_session_implement() is AsyncBasedSession:
-                self.controller = AsyncBasedSession(task_func, on_task_message=self.send_msg_to_client,
-                                                    on_session_close=self.close)
+                self.session = AsyncBasedSession(task_func, on_task_message=self.send_msg_to_client,
+                                                 on_session_close=self.close)
             else:
-                self.controller = ThreadBasedWebIOSession(task_func, on_task_message=self.send_msg_to_client,
-                                                          on_session_close=self.close_from_session,
-                                                          loop=asyncio.get_event_loop())
+                self.session = ThreadBasedWebIOSession(task_func, on_task_message=self.send_msg_to_client,
+                                                       on_session_close=self.close_from_session,
+                                                       loop=asyncio.get_event_loop())
 
         def on_message(self, message):
             data = json.loads(message)
-            self.controller.send_client_event(data)
+            self.session.send_client_event(data)
 
         def close_from_session(self):
-            self._close_from_session = True
+            self._close_from_session_tag = True
             self.close()
 
         def on_close(self):
-            if not self._close_from_session:
-                self.controller.close(no_session_close_callback=True)
+            if not self._close_from_session_tag:
+                self.session.close(no_session_close_callback=True)
             logger.debug("WebSocket closed")
 
     return WSHandler
@@ -69,8 +70,21 @@ async def open_webbrowser_on_server_started(host, port):
         logger.error('Open %s failed.' % url)
 
 
-def start_server(target, port=0, host='', debug=True,
-                 auto_open_webbrowser=False,
+def _setup_server(webio_handler, port=0, host='', **tornado_app_settings):
+    if port == 0:
+        port = get_free_port()
+
+    print('Listen on %s:%s' % (host or '0.0.0.0', port))
+
+    handlers = [(r"/io", webio_handler),
+                (r"/(.*)", StaticFileHandler, {"path": STATIC_PATH, 'default_filename': 'index.html'})]
+
+    app = tornado.web.Application(handlers=handlers, **tornado_app_settings)
+    server = app.listen(port, address=host)
+    return server, port
+
+
+def start_server(target, port=0, host='', debug=False,
                  websocket_max_message_size=None,
                  websocket_ping_interval=None,
                  websocket_ping_timeout=None,
@@ -84,7 +98,6 @@ def start_server(target, port=0, host='', debug=True,
         the server will listen on all IP addresses associated with the name.
         set empty string or to listen on all available interfaces.
     :param bool debug: Tornado debug mode
-    :param bool auto_open_webbrowser: auto open web browser when server started
     :param int websocket_max_message_size: Max bytes of a message which Tornado can accept.
         Messages larger than the ``websocket_max_message_size`` (default 10MiB) will not be accepted.
     :param int websocket_ping_interval: If set to a number, all websockets will be pinged every n seconds.
@@ -98,22 +111,60 @@ def start_server(target, port=0, host='', debug=True,
     :return:
     """
     kwargs = locals()
+
+    mark_server_started()
+
     app_options = ['debug', 'websocket_max_message_size', 'websocket_ping_interval', 'websocket_ping_timeout']
     for opt in app_options:
         if kwargs[opt] is not None:
             tornado_app_settings[opt] = kwargs[opt]
 
-    if port == 0:
-        port = get_free_port()
+    handler = webio_handler(target)
+    _setup_server(webio_handler=handler, port=port, host=host, **tornado_app_settings)
+    tornado.ioloop.IOLoop.current().start()
 
-    print('Listen on %s:%s' % (host or '0.0.0.0', port))
 
-    handlers = [(r"/io", webio_handler(target)),
-                (r"/(.*)", StaticFileHandler, {"path": STATIC_PATH, 'default_filename': 'index.html'})]
+def start_server_in_current_thread_session():
+    mark_server_started()
 
-    app = tornado.web.Application(handlers=handlers, **tornado_app_settings)
-    app.listen(port, address=host)
+    websocket_conn_opened = threading.Event()
+    thread = threading.current_thread()
 
-    if auto_open_webbrowser:
-        tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, host or '0.0.0.0', port)
-    tornado.ioloop.IOLoop.current().start()
+    class SingletonWSHandler(webio_handler(None)):
+        session = None
+
+        def open(self):
+            if SingletonWSHandler.session is None:
+                SingletonWSHandler.session = DesignatedThreadSession(thread, on_task_message=self.send_msg_to_client,
+                                                                   loop=asyncio.get_event_loop())
+                websocket_conn_opened.set()
+            else:
+                self.close()
+
+        def on_close(self):
+            if SingletonWSHandler.session is not None:
+                self.session.close()
+                logger.debug('DesignatedThreadSession.closed')
+
+    async def stoploop_after_thread_stop(thread: threading.Thread):
+        while thread.is_alive():
+            await asyncio.sleep(1)
+        await asyncio.sleep(1)
+        logger.debug('Thread[%s] exit. Closing tornado ioloop...', thread.getName())
+        tornado.ioloop.IOLoop.current().stop()
+
+    def server_thread(task_thread):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        server, port = _setup_server(webio_handler=SingletonWSHandler, host='localhost')
+        tornado.ioloop.IOLoop.current().spawn_callback(stoploop_after_thread_stop, task_thread)
+        tornado.ioloop.IOLoop.current().spawn_callback(open_webbrowser_on_server_started, 'localhost', port)
+
+        tornado.ioloop.IOLoop.current().start()
+        logger.debug('Tornado server exit')
+
+    t = threading.Thread(target=server_thread, args=(threading.current_thread(),), name='Tornado-server')
+    t.start()
+
+    websocket_conn_opened.wait()

+ 36 - 6
pywebio/session/__init__.py

@@ -1,18 +1,27 @@
 import threading
+from functools import wraps
 
 from .asyncbased import AsyncBasedSession
 from .base import AbstractSession
-from .threadbased import ThreadBasedWebIOSession
-from functools import wraps
+from .threadbased import ThreadBasedWebIOSession, DesignatedThreadSession
+from ..exceptions import SessionNotFoundException
 
 _session_type = AsyncBasedSession
 
 __all__ = ['set_session_implement', 'run_async', 'asyncio_coroutine', 'register_thread']
 
+_server_started = False
+
+
+def mark_server_started():
+    """标记服务端已经启动"""
+    global _server_started
+    _server_started = True
+
 
 def set_session_implement(session_type):
     global _session_type
-    assert session_type in [ThreadBasedWebIOSession, AsyncBasedSession]
+    assert session_type in [ThreadBasedWebIOSession, AsyncBasedSession, DesignatedThreadSession]
     _session_type = session_type
 
 
@@ -21,12 +30,32 @@ def get_session_implement():
     return _session_type
 
 
+def _start_script_mode_server():
+    from ..platform import start_server_in_current_thread_session
+    set_session_implement(DesignatedThreadSession)
+    start_server_in_current_thread_session()
+
+
 def get_current_session() -> "AbstractSession":
-    return _session_type.get_current_session()
+    try:
+        return _session_type.get_current_session()
+    except SessionNotFoundException:
+        if _server_started:
+            raise
+        # 没有显式启动backend server时,在当前线程上下文作为session启动backend server
+        _start_script_mode_server()
+        return _session_type.get_current_session()
 
 
 def get_current_task_id():
-    return _session_type.get_current_task_id()
+    try:
+        return _session_type.get_current_task_id()
+    except RuntimeError:
+        if _server_started:
+            raise
+        # 没有显式启动backend server时,在当前线程上下文作为session启动backend server
+        _start_script_mode_server()
+        return _session_type.get_current_task_id()
 
 
 def check_session_impl(session_type):
@@ -34,7 +63,8 @@ def check_session_impl(session_type):
         @wraps(func)
         def inner(*args, **kwargs):
             now_impl = get_session_implement()
-            if now_impl is not session_type:
+            if not issubclass(now_impl,
+                              session_type):  # Check if 'now_impl' is a derived from session_type or is the same class
                 func_name = getattr(func, '__name__', str(func))
                 require = getattr(session_type, '__name__', str(session_type))
                 now = getattr(now_impl, '__name__', str(now_impl))

+ 2 - 1
pywebio/session/asyncbased.py

@@ -6,6 +6,7 @@ import traceback
 from contextlib import contextmanager
 
 from .base import AbstractSession
+from ..exceptions import SessionNotFoundException
 from ..utils import random_str
 
 logger = logging.getLogger(__name__)
@@ -42,7 +43,7 @@ class AsyncBasedSession(AbstractSession):
     @staticmethod
     def get_current_session() -> "AsyncBasedSession":
         if _context.current_session is None:
-            raise RuntimeError("No current found in context!")
+            raise SessionNotFoundException("No current found in context!")
         return _context.current_session
 
     @staticmethod

+ 41 - 2
pywebio/session/threadbased.py

@@ -1,10 +1,13 @@
+import asyncio
+import inspect
 import logging
 import queue
 import sys
 import threading
 import traceback
-import asyncio, inspect
+
 from .base import AbstractSession
+from ..exceptions import SessionNotFoundException
 from ..utils import random_str
 
 logger = logging.getLogger(__name__)
@@ -31,7 +34,7 @@ class ThreadBasedWebIOSession(AbstractSession):
         curr = threading.current_thread().getName()
         session = cls.thread2session.get(curr)
         if session is None:
-            raise RuntimeError("Can't find current session. Maybe session closed.")
+            raise SessionNotFoundException("Can't find current session. Maybe session closed.")
         return session
 
     @staticmethod
@@ -44,6 +47,8 @@ class ThreadBasedWebIOSession(AbstractSession):
         :param on_coro_msg: 由协程内发给session的消息的处理函数
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
             需要保证会话内的所有消息都传送到了客户端
+        :param loop: 事件循环。若on_task_message或者on_session_close中有调用使用asyncio事件循环的调用,
+            则需要事件循环实例来将回调在事件循环的线程中执行
         """
         self._on_task_message = on_task_message or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
@@ -64,6 +69,9 @@ class ThreadBasedWebIOSession(AbstractSession):
         self._start_main_task(target)
 
     def _start_main_task(self, target):
+        assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
+            "In ThreadBasedWebIOSession.__init__, `target` must be a simple function, "
+            "not coroutine function or generator function. ")
 
         def thread_task(target):
             try:
@@ -229,3 +237,34 @@ class ThreadBasedWebIOSession(AbstractSession):
         self.thread2session[tname] = self
         event_mq = queue.Queue(maxsize=self.event_mq_maxsize)
         self.event_mqs[tname] = event_mq
+
+
+class DesignatedThreadSession(ThreadBasedWebIOSession):
+    """以指定进程为会话"""
+
+    def __init__(self, thread, on_task_message=None, loop=None):
+        """
+        :param on_coro_msg: 由协程内发给session的消息的处理函数
+        :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
+            需要保证会话内的所有消息都传送到了客户端
+        :param loop: 事件循环。若on_task_message或者on_session_close中有调用使用asyncio事件循环的调用,
+            则需要事件循环实例来将回调在事件循环的线程中执行
+
+        """
+        self._on_task_message = on_task_message or (lambda _: None)
+        self._on_session_close = lambda: None
+        self._loop = loop
+
+        self._server_msg_lock = threading.Lock()
+        self.threads = []  # 当前会话的线程id集合,用户会话结束后,清理数据
+        self.unhandled_task_msgs = []
+
+        self.event_mqs = {}  # thread_id -> event msg queue
+        self._closed = False
+
+        # 用于实现回调函数的注册
+        self.callback_mq = None
+        self.callback_thread = None
+        self.callbacks = {}  # callback_id -> (callback_func, is_mutex)
+
+        self.register_thread(thread, as_daemon=False)

+ 6 - 0
pywebio/utils.py

@@ -6,6 +6,12 @@ import time
 from collections import OrderedDict
 from contextlib import closing
 
+from os.path import abspath, dirname
+
+project_dir = dirname(abspath(__file__))
+
+STATIC_PATH = '%s/html' % project_dir
+
 
 async def wait_host_port(host, port, duration=10, delay=2):
     """Repeatedly try if a port on a host is open until duration seconds passed