浏览代码

use Session.active_session_count() to detect if session started

wangweimin 5 年之前
父节点
当前提交
37e6117bad
共有 4 个文件被更改,包括 44 次插入13 次删除
  1. 13 9
      pywebio/session/__init__.py
  2. 5 0
      pywebio/session/base.py
  3. 9 0
      pywebio/session/coroutinebased.py
  4. 17 4
      pywebio/session/threadbased.py

+ 13 - 9
pywebio/session/__init__.py

@@ -1,4 +1,6 @@
-import threading, asyncio, inspect
+import asyncio
+import inspect
+import threading
 from functools import wraps
 
 from .base import AbstractSession
@@ -56,22 +58,24 @@ def get_current_session() -> "AbstractSession":
     try:
         return _session_type.get_current_session()
     except SessionNotFoundException:
-        if _server_started:
+        # 如果没已经运行的backend server,在当前线程上下文作为session启动backend server
+        if get_session_implement().active_session_count() == 0:
+            _start_script_mode_server()
+            return _session_type.get_current_session()
+        else:
             raise
-        # 没有显式启动backend server时,在当前线程上下文作为session启动backend server
-        _start_script_mode_server()
-        return _session_type.get_current_session()
 
 
 def get_current_task_id():
     try:
         return _session_type.get_current_task_id()
     except RuntimeError:
-        if _server_started:
+        # 如果没已经运行的backend server,在当前线程上下文作为session启动backend server
+        if get_session_implement().active_session_count() == 0:
+            _start_script_mode_server()
+            return _session_type.get_current_session()
+        else:
             raise
-        # 没有显式启动backend server时,在当前线程上下文作为session启动backend server
-        _start_script_mode_server()
-        return _session_type.get_current_task_id()
 
 
 def check_session_impl(session_type):

+ 5 - 0
pywebio/session/base.py

@@ -19,6 +19,7 @@ class AbstractSession:
 
     Task和Backend都可调用:
         closed
+        active_session_count
 
 
     Session是不同的后端Backend与协程交互的桥梁:
@@ -26,6 +27,10 @@ class AbstractSession:
         Task内在调用输入输出函数后,会调用 ``send_task_command`` 向会话发送输入输出消息指令, Session将其保存并留给后端Backend处理。
     """
 
+    @staticmethod
+    def active_session_count() -> int:
+        raise NotImplementedError
+
     @staticmethod
     def get_current_session() -> "AbstractSession":
         raise NotImplementedError

+ 9 - 0
pywebio/session/coroutinebased.py

@@ -36,6 +36,12 @@ class CoroutineBasedSession(AbstractSession):
     当用户浏览器主动关闭会话,CoroutineBasedSession.close 被调用, 协程任务和会话内所有通过 `run_async` 注册的协程都被关闭。
     """
 
+    _active_session_cnt = 0
+
+    @classmethod
+    def active_session_count(cls):
+        return cls._active_session_cnt
+
     @staticmethod
     def get_current_session() -> "CoroutineBasedSession":
         if _context.current_session is None:
@@ -57,6 +63,8 @@ class CoroutineBasedSession(AbstractSession):
         assert asyncio.iscoroutinefunction(target) or inspect.isgeneratorfunction(target), ValueError(
             "CoroutineBasedSession accept coroutine function or generator function as task function")
 
+        CoroutineBasedSession._active_session_cnt += 1
+
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
         self.unhandled_task_msgs = []
@@ -120,6 +128,7 @@ class CoroutineBasedSession(AbstractSession):
         for t in self.coros.values():
             t.close()
         self.coros = {}  # delete session tasks
+        CoroutineBasedSession._active_session_cnt -= 1
 
     def close(self):
         """关闭当前Session。由Backend调用"""

+ 17 - 4
pywebio/session/threadbased.py

@@ -7,7 +7,7 @@ import threading
 import traceback
 
 from .base import AbstractSession
-from ..exceptions import SessionNotFoundException
+from ..exceptions import SessionNotFoundException, SessionClosedException
 from ..utils import random_str
 
 logger = logging.getLogger(__name__)
@@ -29,6 +29,12 @@ class ThreadBasedSession(AbstractSession):
     event_mq_maxsize = 100
     callback_mq_maxsize = 100
 
+    _active_session_cnt = 0
+
+    @classmethod
+    def active_session_count(cls):
+        return cls._active_session_cnt
+
     @classmethod
     def get_current_session(cls) -> "ThreadBasedSession":
         curr = id(threading.current_thread())
@@ -55,6 +61,12 @@ class ThreadBasedSession(AbstractSession):
         :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
             则需要事件循环实例来将回调在事件循环的线程中执行
         """
+        assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
+            "ThreadBasedSession only accept a simple function as task function, "
+            "not coroutine function or generator function. ")
+
+        ThreadBasedSession._active_session_cnt += 1
+
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_session_close = on_session_close or (lambda: None)
         self._loop = loop
@@ -74,9 +86,6 @@ class ThreadBasedSession(AbstractSession):
         self._start_main_task(target)
 
     def _start_main_task(self, target):
-        assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
-            "ThreadBasedSession only accept a simple function as task function, "
-            "not coroutine function or generator function. ")
 
         def thread_task(target):
             try:
@@ -156,6 +165,8 @@ class ThreadBasedSession(AbstractSession):
         if self.callback_mq is not None:  # 回调功能已经激活
             self.callback_mq.put(None)  # 结束回调线程
 
+        ThreadBasedSession._active_session_cnt -= 1
+
     def close(self):
         """关闭当前Session。由Backend调用"""
         if self._closed:
@@ -279,6 +290,8 @@ class ScriptModeSession(ThreadBasedSession):
             raise RuntimeError("ScriptModeSession can only be created once.")
         ScriptModeSession.instance = self
 
+        ThreadBasedSession._active_session_cnt += 1
+
         self._on_task_command = on_task_command or (lambda _: None)
         self._on_session_close = lambda: None
         self._loop = loop