|
@@ -0,0 +1,230 @@
|
|
|
|
+import logging
|
|
|
|
+import queue
|
|
|
|
+import sys
|
|
|
|
+import threading
|
|
|
|
+import traceback
|
|
|
|
+import asyncio, inspect
|
|
|
|
+from .base import AbstractSession
|
|
|
|
+from ..utils import random_str
|
|
|
|
+
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
+
|
|
|
|
+"""
|
|
|
|
+基于线程的会话实现
|
|
|
|
+
|
|
|
|
+主任务线程退出后,连接关闭,但不会清理主任务线程产生的其他线程
|
|
|
|
+
|
|
|
|
+客户端连接关闭后,后端线程不会退出,但是再次调用
|
|
|
|
+todo: thread 重名
|
|
|
|
+"""
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# todo 线程安全
|
|
|
|
+class ThreadBasedWebIOSession(AbstractSession):
|
|
|
|
+ thread2session = {} # thread_id -> session
|
|
|
|
+
|
|
|
|
+ event_mq_maxsize = 100
|
|
|
|
+ callback_mq_maxsize = 100
|
|
|
|
+
|
|
|
|
+ @classmethod
|
|
|
|
+ def get_current_session(cls) -> "ThreadBasedWebIOSession":
|
|
|
|
+ curr = threading.current_thread().getName()
|
|
|
|
+ session = cls.thread2session.get(curr)
|
|
|
|
+ if session is None:
|
|
|
|
+ raise RuntimeError("Can't find current session. Maybe session closed.")
|
|
|
|
+ return session
|
|
|
|
+
|
|
|
|
+ @staticmethod
|
|
|
|
+ def get_current_task_id():
|
|
|
|
+ return threading.current_thread().getName()
|
|
|
|
+
|
|
|
|
+ def __init__(self, target, on_task_message=None, on_session_close=None, loop=None):
|
|
|
|
+ """
|
|
|
|
+ :param target_func: 会话运行的函数
|
|
|
|
+ :param on_coro_msg: 由协程内发给session的消息的处理函数
|
|
|
|
+ :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
|
|
|
|
+ 需要保证会话内的所有消息都传送到了客户端
|
|
|
|
+ """
|
|
|
|
+ self._on_task_message = on_task_message or (lambda _: None)
|
|
|
|
+ self._on_session_close = on_session_close or (lambda: None)
|
|
|
|
+ self._loop = loop
|
|
|
|
+
|
|
|
|
+ self._server_msg_lock = threading.Lock()
|
|
|
|
+ self.threads = [] # 当前会话的线程id集合,用户会话结束后,清理数据
|
|
|
|
+ self.unhandled_task_msgs = []
|
|
|
|
+
|
|
|
|
+ self.event_mqs = {} # thread_id -> event msg queue
|
|
|
|
+ self._closed = False
|
|
|
|
+
|
|
|
|
+ # 用于实现回调函数的注册
|
|
|
|
+ self.callback_mq = None
|
|
|
|
+ self.callback_thread = None
|
|
|
|
+ self.callbacks = {} # callback_id -> (callback_func, is_mutex)
|
|
|
|
+
|
|
|
|
+ main_task = self._new_thread_task(target, on_close=self.close)
|
|
|
|
+ self.register_thread(main_task)
|
|
|
|
+
|
|
|
|
+ main_task.start()
|
|
|
|
+
|
|
|
|
+ def _new_thread_task(self, target, on_close=None):
|
|
|
|
+
|
|
|
|
+ def thread_task(target):
|
|
|
|
+ try:
|
|
|
|
+ target()
|
|
|
|
+ except Exception as e:
|
|
|
|
+ self.on_task_exception()
|
|
|
|
+ finally:
|
|
|
|
+ if on_close:
|
|
|
|
+ on_close()
|
|
|
|
+
|
|
|
|
+ task_name = '%s-%s' % (target.__name__, random_str(10))
|
|
|
|
+ thread = threading.Thread(target=thread_task, kwargs=dict(target=target),
|
|
|
|
+ daemon=True, name=task_name)
|
|
|
|
+
|
|
|
|
+ return thread
|
|
|
|
+
|
|
|
|
+ def send_task_message(self, message):
|
|
|
|
+ """向会话发送来自协程内的消息
|
|
|
|
+
|
|
|
|
+ :param dict message: 消息
|
|
|
|
+ """
|
|
|
|
+ with self._server_msg_lock:
|
|
|
|
+ self.unhandled_task_msgs.append(message)
|
|
|
|
+ if self._loop:
|
|
|
|
+ self._loop.call_soon_threadsafe(self._on_task_message, self)
|
|
|
|
+ else:
|
|
|
|
+ self._on_task_message(self)
|
|
|
|
+
|
|
|
|
+ def next_client_event(self):
|
|
|
|
+ name = threading.current_thread().getName()
|
|
|
|
+ event_mq = self.get_current_session().event_mqs.get(name)
|
|
|
|
+ return event_mq.get()
|
|
|
|
+
|
|
|
|
+ def send_client_event(self, event):
|
|
|
|
+ """向会话发送来自用户浏览器的事件️
|
|
|
|
+
|
|
|
|
+ :param dict event: 事件️消息
|
|
|
|
+ """
|
|
|
|
+ task_id = event['coro_id']
|
|
|
|
+ mq = self.event_mqs.get(task_id)
|
|
|
|
+ if not mq and task_id in self.callbacks:
|
|
|
|
+ mq = self.callback_mq
|
|
|
|
+
|
|
|
|
+ if not mq:
|
|
|
|
+ logger.error('event_mqs not found, task_id:%s', task_id)
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ mq.put(event)
|
|
|
|
+
|
|
|
|
+ def get_task_messages(self):
|
|
|
|
+ with self._server_msg_lock:
|
|
|
|
+ msgs = self.unhandled_task_msgs
|
|
|
|
+ self.unhandled_task_msgs = []
|
|
|
|
+ return msgs
|
|
|
|
+
|
|
|
|
+ def _cleanup(self):
|
|
|
|
+ self.event_mqs = {}
|
|
|
|
+ self.unhandled_task_msgs = []
|
|
|
|
+ for t in self.threads:
|
|
|
|
+ del ThreadBasedWebIOSession.thread2session[t]
|
|
|
|
+ # pass
|
|
|
|
+
|
|
|
|
+ if self.callback_mq is not None: # 回调功能已经激活
|
|
|
|
+ self.callback_mq.put(None) # 结束回调线程
|
|
|
|
+
|
|
|
|
+ def close(self, no_session_close_callback=False):
|
|
|
|
+ """关闭当前Session
|
|
|
|
+
|
|
|
|
+ :param bool no_session_close_callback: 不调用 on_session_close 会话结束的处理函数。
|
|
|
|
+ 当 close 是由后端Backend调用时可能希望开启 no_session_close_callback
|
|
|
|
+ """
|
|
|
|
+ self._cleanup()
|
|
|
|
+ self._closed = True
|
|
|
|
+ if not no_session_close_callback:
|
|
|
|
+ if self._loop:
|
|
|
|
+ self._loop.call_soon_threadsafe(self._on_session_close)
|
|
|
|
+ else:
|
|
|
|
+ self._on_session_close()
|
|
|
|
+
|
|
|
|
+ def closed(self):
|
|
|
|
+ return self._closed
|
|
|
|
+
|
|
|
|
+ def on_task_exception(self):
|
|
|
|
+ from ..output import put_markdown # todo
|
|
|
|
+ logger.exception('Error in coroutine executing')
|
|
|
|
+ type, value, tb = sys.exc_info()
|
|
|
|
+ tb_len = len(list(traceback.walk_tb(tb)))
|
|
|
|
+ lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
|
|
|
|
+ traceback_msg = ''.join(lines)
|
|
|
|
+ put_markdown("发生错误:\n```\n%s\n```" % traceback_msg)
|
|
|
|
+
|
|
|
|
+ def _activate_callback_env(self):
|
|
|
|
+ """激活回调功能
|
|
|
|
+
|
|
|
|
+ ThreadBasedWebIOSession的回调实现原理是:创建一个单独的线程用于接收回调事件,进而调用相关的回调函数。
|
|
|
|
+ 当用户Task中并没有使用到回调功能时,不必开启此线程,可以节省资源
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ if self.callback_mq is not None: # 回调功能已经激活
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ self.callback_mq = queue.Queue(maxsize=self.callback_mq_maxsize)
|
|
|
|
+ self.callback_thread = threading.Thread(target=self._dispatch_callback_event,
|
|
|
|
+ daemon=True, name='callback-' + random_str(10))
|
|
|
|
+ self.register_thread(self.callback_thread)
|
|
|
|
+ self.callback_thread.start()
|
|
|
|
+ logger.debug('Callback thread start')
|
|
|
|
+
|
|
|
|
+ def _dispatch_callback_event(self):
|
|
|
|
+ while not self.closed():
|
|
|
|
+ event = self.callback_mq.get()
|
|
|
|
+ if event is None: # 结束信号
|
|
|
|
+ break
|
|
|
|
+ callback_info = self.callbacks.get(event['coro_id'])
|
|
|
|
+ if not callback_info:
|
|
|
|
+ logger.error("No callback for coro_id:%s", event['coro_id'])
|
|
|
|
+ return
|
|
|
|
+ callback, mutex = callback_info
|
|
|
|
+
|
|
|
|
+ def run(callback):
|
|
|
|
+ try:
|
|
|
|
+ callback(event['data'])
|
|
|
|
+ except:
|
|
|
|
+ ThreadBasedWebIOSession.get_current_session().on_task_exception()
|
|
|
|
+
|
|
|
|
+ if mutex:
|
|
|
|
+ run(callback)
|
|
|
|
+ else:
|
|
|
|
+ t = threading.Thread(target=run, kwargs=dict(callback=callback),
|
|
|
|
+ daemon=True, name=event['coro_id'])
|
|
|
|
+ self.register_thread(t)
|
|
|
|
+ t.start()
|
|
|
|
+
|
|
|
|
+ def register_callback(self, callback, serial_mode):
|
|
|
|
+ """ 向Session注册一个回调函数,返回回调id
|
|
|
|
+
|
|
|
|
+ Session需要保证当收到前端发送的事件消息 ``{event: "callback",coro_id: 回调id, data:...}`` 时,
|
|
|
|
+ ``callback`` 回调函数被执行, 并传入事件消息中的 ``data`` 字段值作为参数
|
|
|
|
+
|
|
|
|
+ :param bool serial_mode: 串行模式模式。若为 ``True`` ,则对于同一组件的点击事件,串行执行其回调函数
|
|
|
|
+ """
|
|
|
|
+ assert (not asyncio.iscoroutinefunction(callback)) and (not inspect.isgeneratorfunction(callback)), ValueError(
|
|
|
|
+ "In ThreadBasedWebIOSession.register_callback, `callback` must be a simple function, "
|
|
|
|
+ "not coroutine function or generator function. ")
|
|
|
|
+
|
|
|
|
+ self._activate_callback_env()
|
|
|
|
+ callback_id = 'CB-%s-%s' % (getattr(callback, '__name__', ''), random_str(10))
|
|
|
|
+ self.callbacks[callback_id] = (callback, serial_mode)
|
|
|
|
+ return callback_id
|
|
|
|
+
|
|
|
|
+ def register_thread(self, t: threading.Thread, as_daemon=True):
|
|
|
|
+ """注册线程,以便在线程内调用 pywebio 交互函数"""
|
|
|
|
+ if as_daemon:
|
|
|
|
+ t.setDaemon(True)
|
|
|
|
+ tname = t.getName()
|
|
|
|
+ self.threads.append(tname)
|
|
|
|
+ self.thread2session[tname] = self
|
|
|
|
+ event_mq = queue.Queue(maxsize=self.event_mq_maxsize)
|
|
|
|
+ self.event_mqs[tname] = event_mq
|
|
|
|
+
|
|
|
|
+ return event_mq
|