Bläddra i källkod

docs: code comment

wangweimin 5 år sedan
förälder
incheckning
88fbe8a8ba
2 ändrade filer med 12 tillägg och 8 borttagningar
  1. 6 3
      pywebio/session/coroutinebased.py
  2. 6 5
      pywebio/session/threadbased.py

+ 6 - 3
pywebio/session/coroutinebased.py

@@ -52,12 +52,15 @@ class CoroutineBasedSession(AbstractSession):
             raise RuntimeError("No current task found in context!")
         return _context.current_task_id
 
-    def __init__(self, coroutine_func, on_task_command=None, on_session_close=None):
+    def __init__(self, target, on_task_command=None, on_session_close=None):
         """
-        :param coroutine_func: 协程函数
+        :param target: 协程函数
         :param on_task_command: 由协程内发给session的消息的处理函数
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
         """
+        assert asyncio.iscoroutinefunction(target) or inspect.isgeneratorfunction(target), ValueError(
+            "In CoroutineBasedSession accept coroutine function or generator function as task function")
+
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
         self.unhandled_task_msgs = []
@@ -67,7 +70,7 @@ class CoroutineBasedSession(AbstractSession):
         self._closed = False
         self.inactive_coro_instances = []  # 待激活的协程实例列表
 
-        self.main_task = Task(coroutine_func(), session=self, on_coro_stop=self._on_main_task_finish)
+        self.main_task = Task(target(), session=self, on_coro_stop=self._on_main_task_finish)
         self.coros[self.main_task.coro_id] = self.main_task
 
         self._step_task(self.main_task)

+ 6 - 5
pywebio/session/threadbased.py

@@ -34,7 +34,7 @@ class ThreadBasedSession(AbstractSession):
         curr = threading.current_thread().getName()
         session = cls.thread2session.get(curr)
         if session is None:
-            raise SessionNotFoundException("Can't find current session. Maybe session closed.")
+            raise SessionNotFoundException("Can't find current session. Maybe session closed. Did you forget to use `register_thread` ?")
         return session
 
     @staticmethod
@@ -43,8 +43,8 @@ class ThreadBasedSession(AbstractSession):
 
     def __init__(self, target, on_task_command=None, on_session_close=None, loop=None):
         """
-        :param target_func: 会话运行的函数
-        :param on_coro_msg: 由协程内发给session的消息的处理函数
+        :param target: 会话运行的函数
+        :param on_task_command: 当Task内发送Command给session的时候触发的处理函数
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,
             需要保证会话内的所有消息都传送到了客户端
         :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
@@ -58,7 +58,7 @@ class ThreadBasedSession(AbstractSession):
         self.threads = []  # 当前会话的线程id集合,用户会话结束后,清理数据
         self.unhandled_task_msgs = []
 
-        self.event_mqs = {}  # thread_id -> event msg queue
+        self.event_mqs = {}  # task_id -> event msg queue
         self._closed = False
 
         # 用于实现回调函数的注册
@@ -197,7 +197,8 @@ class ThreadBasedSession(AbstractSession):
                 try:
                     callback(event['data'])
                 except:
-                    ThreadBasedSession.get_current_session().on_task_exception()
+                    # 子类可能会重写 get_current_session ,所以不要用 ThreadBasedSession.get_current_session 来调用
+                    self.get_current_session().on_task_exception()
 
             if mutex:
                 run(callback)