瀏覽代碼

fix: partial asynchronous/generator functions are not correct detected

wangweimin 5 年之前
父節點
當前提交
efc45f4b6f
共有 4 個文件被更改,包括 25 次插入16 次删除
  1. 2 3
      pywebio/session/__init__.py
  2. 4 5
      pywebio/session/coroutinebased.py
  3. 4 6
      pywebio/session/threadbased.py
  4. 15 2
      pywebio/utils.py

+ 2 - 3
pywebio/session/__init__.py

@@ -6,8 +6,6 @@ r"""
    :members:
    :members:
 """
 """
 
 
-import asyncio
-import inspect
 import threading
 import threading
 from functools import wraps
 from functools import wraps
 
 
@@ -15,6 +13,7 @@ from .base import AbstractSession
 from .coroutinebased import CoroutineBasedSession
 from .coroutinebased import CoroutineBasedSession
 from .threadbased import ThreadBasedSession, ScriptModeSession
 from .threadbased import ThreadBasedSession, ScriptModeSession
 from ..exceptions import SessionNotFoundException
 from ..exceptions import SessionNotFoundException
+from ..utils import iscoroutinefunction, isgeneratorfunction
 
 
 # 当前进程中正在使用的会话实现的列表
 # 当前进程中正在使用的会话实现的列表
 _active_session_cls = []
 _active_session_cls = []
@@ -24,7 +23,7 @@ __all__ = ['run_async', 'run_asyncio_coroutine', 'register_thread']
 
 
 def register_session_implement_for_target(target_func):
 def register_session_implement_for_target(target_func):
     """根据target_func函数类型注册会话实现,并返回会话实现"""
     """根据target_func函数类型注册会话实现,并返回会话实现"""
-    if asyncio.iscoroutinefunction(target_func) or inspect.isgeneratorfunction(target_func):
+    if iscoroutinefunction(target_func) or isgeneratorfunction(target_func):
         cls = CoroutineBasedSession
         cls = CoroutineBasedSession
     else:
     else:
         cls = ThreadBasedSession
         cls = ThreadBasedSession

+ 4 - 5
pywebio/session/coroutinebased.py

@@ -1,5 +1,4 @@
 import asyncio
 import asyncio
-import inspect
 import logging
 import logging
 import sys
 import sys
 import threading
 import threading
@@ -8,7 +7,7 @@ from contextlib import contextmanager
 
 
 from .base import AbstractSession
 from .base import AbstractSession
 from ..exceptions import SessionNotFoundException, SessionClosedException
 from ..exceptions import SessionNotFoundException, SessionClosedException
-from ..utils import random_str
+from ..utils import random_str, isgeneratorfunction, iscoroutinefunction
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -62,7 +61,7 @@ class CoroutineBasedSession(AbstractSession):
         :param on_task_command: 由协程内发给session的消息的处理函数
         :param on_task_command: 由协程内发给session的消息的处理函数
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
         """
         """
-        assert asyncio.iscoroutinefunction(target) or inspect.isgeneratorfunction(target), ValueError(
+        assert iscoroutinefunction(target) or isgeneratorfunction(target), ValueError(
             "CoroutineBasedSession accept coroutine function or generator function as task function")
             "CoroutineBasedSession accept coroutine function or generator function as task function")
 
 
         CoroutineBasedSession._active_session_cnt += 1
         CoroutineBasedSession._active_session_cnt += 1
@@ -181,9 +180,9 @@ class CoroutineBasedSession(AbstractSession):
                 event = await self.next_client_event()
                 event = await self.next_client_event()
                 assert event['event'] == 'callback'
                 assert event['event'] == 'callback'
                 coro = None
                 coro = None
-                if asyncio.iscoroutinefunction(callback):
+                if iscoroutinefunction(callback):
                     coro = callback(event['data'])
                     coro = callback(event['data'])
-                elif inspect.isgeneratorfunction(callback):
+                elif isgeneratorfunction(callback):
                     coro = asyncio.coroutine(callback)(event['data'])
                     coro = asyncio.coroutine(callback)(event['data'])
                 else:
                 else:
                     try:
                     try:

+ 4 - 6
pywebio/session/threadbased.py

@@ -1,5 +1,3 @@
-import asyncio
-import inspect
 import logging
 import logging
 import queue
 import queue
 import sys
 import sys
@@ -9,7 +7,7 @@ from functools import wraps
 
 
 from .base import AbstractSession
 from .base import AbstractSession
 from ..exceptions import SessionNotFoundException, SessionClosedException
 from ..exceptions import SessionNotFoundException, SessionClosedException
-from ..utils import random_str, LimitedSizeQueue
+from ..utils import random_str, LimitedSizeQueue, isgeneratorfunction, iscoroutinefunction
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -64,7 +62,7 @@ class ThreadBasedSession(AbstractSession):
         :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
         :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
             则需要事件循环实例来将回调在事件循环的线程中执行
             则需要事件循环实例来将回调在事件循环的线程中执行
         """
         """
-        assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
+        assert (not iscoroutinefunction(target)) and (not isgeneratorfunction(target)), ValueError(
             "ThreadBasedSession only accept a simple function as task function, "
             "ThreadBasedSession only accept a simple function as task function, "
             "not coroutine function or generator function. ")
             "not coroutine function or generator function. ")
 
 
@@ -183,7 +181,7 @@ class ThreadBasedSession(AbstractSession):
 
 
     def on_task_exception(self):
     def on_task_exception(self):
         from ..output import put_markdown  # todo
         from ..output import put_markdown  # todo
-        logger.exception('Error in coroutine executing')
+        logger.exception('Error in thread executing')
         type, value, tb = sys.exc_info()
         type, value, tb = sys.exc_info()
         tb_len = len(list(traceback.walk_tb(tb)))
         tb_len = len(list(traceback.walk_tb(tb)))
         lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
         lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
@@ -248,7 +246,7 @@ class ThreadBasedSession(AbstractSession):
 
 
         :param bool serial_mode: 串行模式模式。若为 ``True`` ,则对于同一组件的点击事件,串行执行其回调函数
         :param bool serial_mode: 串行模式模式。若为 ``True`` ,则对于同一组件的点击事件,串行执行其回调函数
         """
         """
-        assert (not asyncio.iscoroutinefunction(callback)) and (not inspect.isgeneratorfunction(callback)), ValueError(
+        assert (not iscoroutinefunction(callback)) and (not isgeneratorfunction(callback)), ValueError(
             "In ThreadBasedSession.register_callback, `callback` must be a simple function, "
             "In ThreadBasedSession.register_callback, `callback` must be a simple function, "
             "not coroutine function or generator function. ")
             "not coroutine function or generator function. ")
 
 

+ 15 - 2
pywebio/utils.py

@@ -1,12 +1,13 @@
 import asyncio
 import asyncio
+import functools
+import inspect
+import queue
 import random
 import random
 import socket
 import socket
 import string
 import string
 import time
 import time
 from collections import OrderedDict
 from collections import OrderedDict
 from contextlib import closing
 from contextlib import closing
-import queue
-
 from os.path import abspath, dirname
 from os.path import abspath, dirname
 
 
 project_dir = dirname(abspath(__file__))
 project_dir = dirname(abspath(__file__))
@@ -14,6 +15,18 @@ project_dir = dirname(abspath(__file__))
 STATIC_PATH = '%s/html' % project_dir
 STATIC_PATH = '%s/html' % project_dir
 
 
 
 
+def iscoroutinefunction(object):
+    while isinstance(object, functools.partial):
+        object = object.func
+    return asyncio.iscoroutinefunction(object)
+
+
+def isgeneratorfunction(object):
+    while isinstance(object, functools.partial):
+        object = object.func
+    return inspect.isgeneratorfunction(object)
+
+
 class LimitedSizeQueue(queue.Queue):
 class LimitedSizeQueue(queue.Queue):
     """
     """
     有限大小的队列
     有限大小的队列