1
0
Эх сурвалжийг харах

fix: partial asynchronous/generator functions are not correct detected

wangweimin 5 жил өмнө
parent
commit
efc45f4b6f

+ 2 - 3
pywebio/session/__init__.py

@@ -6,8 +6,6 @@ r"""
    :members:
 """
 
-import asyncio
-import inspect
 import threading
 from functools import wraps
 
@@ -15,6 +13,7 @@ from .base import AbstractSession
 from .coroutinebased import CoroutineBasedSession
 from .threadbased import ThreadBasedSession, ScriptModeSession
 from ..exceptions import SessionNotFoundException
+from ..utils import iscoroutinefunction, isgeneratorfunction
 
 # 当前进程中正在使用的会话实现的列表
 _active_session_cls = []
@@ -24,7 +23,7 @@ __all__ = ['run_async', 'run_asyncio_coroutine', 'register_thread']
 
 def register_session_implement_for_target(target_func):
     """根据target_func函数类型注册会话实现,并返回会话实现"""
-    if asyncio.iscoroutinefunction(target_func) or inspect.isgeneratorfunction(target_func):
+    if iscoroutinefunction(target_func) or isgeneratorfunction(target_func):
         cls = CoroutineBasedSession
     else:
         cls = ThreadBasedSession

+ 4 - 5
pywebio/session/coroutinebased.py

@@ -1,5 +1,4 @@
 import asyncio
-import inspect
 import logging
 import sys
 import threading
@@ -8,7 +7,7 @@ from contextlib import contextmanager
 
 from .base import AbstractSession
 from ..exceptions import SessionNotFoundException, SessionClosedException
-from ..utils import random_str
+from ..utils import random_str, isgeneratorfunction, iscoroutinefunction
 
 logger = logging.getLogger(__name__)
 
@@ -62,7 +61,7 @@ class CoroutineBasedSession(AbstractSession):
         :param on_task_command: 由协程内发给session的消息的处理函数
         :param on_session_close: 会话结束的处理函数。后端Backend在相应on_session_close时关闭连接时,需要保证会话内的所有消息都传送到了客户端
         """
-        assert asyncio.iscoroutinefunction(target) or inspect.isgeneratorfunction(target), ValueError(
+        assert iscoroutinefunction(target) or isgeneratorfunction(target), ValueError(
             "CoroutineBasedSession accept coroutine function or generator function as task function")
 
         CoroutineBasedSession._active_session_cnt += 1
@@ -181,9 +180,9 @@ class CoroutineBasedSession(AbstractSession):
                 event = await self.next_client_event()
                 assert event['event'] == 'callback'
                 coro = None
-                if asyncio.iscoroutinefunction(callback):
+                if iscoroutinefunction(callback):
                     coro = callback(event['data'])
-                elif inspect.isgeneratorfunction(callback):
+                elif isgeneratorfunction(callback):
                     coro = asyncio.coroutine(callback)(event['data'])
                 else:
                     try:

+ 4 - 6
pywebio/session/threadbased.py

@@ -1,5 +1,3 @@
-import asyncio
-import inspect
 import logging
 import queue
 import sys
@@ -9,7 +7,7 @@ from functools import wraps
 
 from .base import AbstractSession
 from ..exceptions import SessionNotFoundException, SessionClosedException
-from ..utils import random_str, LimitedSizeQueue
+from ..utils import random_str, LimitedSizeQueue, isgeneratorfunction, iscoroutinefunction
 
 logger = logging.getLogger(__name__)
 
@@ -64,7 +62,7 @@ class ThreadBasedSession(AbstractSession):
         :param loop: 事件循环。若 on_task_command 或者 on_session_close 中有调用使用asyncio事件循环的调用,
             则需要事件循环实例来将回调在事件循环的线程中执行
         """
-        assert (not asyncio.iscoroutinefunction(target)) and (not inspect.isgeneratorfunction(target)), ValueError(
+        assert (not iscoroutinefunction(target)) and (not isgeneratorfunction(target)), ValueError(
             "ThreadBasedSession only accept a simple function as task function, "
             "not coroutine function or generator function. ")
 
@@ -183,7 +181,7 @@ class ThreadBasedSession(AbstractSession):
 
     def on_task_exception(self):
         from ..output import put_markdown  # todo
-        logger.exception('Error in coroutine executing')
+        logger.exception('Error in thread executing')
         type, value, tb = sys.exc_info()
         tb_len = len(list(traceback.walk_tb(tb)))
         lines = traceback.format_exception(type, value, tb, limit=1 - tb_len)
@@ -248,7 +246,7 @@ class ThreadBasedSession(AbstractSession):
 
         :param bool serial_mode: 串行模式模式。若为 ``True`` ,则对于同一组件的点击事件,串行执行其回调函数
         """
-        assert (not asyncio.iscoroutinefunction(callback)) and (not inspect.isgeneratorfunction(callback)), ValueError(
+        assert (not iscoroutinefunction(callback)) and (not isgeneratorfunction(callback)), ValueError(
             "In ThreadBasedSession.register_callback, `callback` must be a simple function, "
             "not coroutine function or generator function. ")
 

+ 15 - 2
pywebio/utils.py

@@ -1,12 +1,13 @@
 import asyncio
+import functools
+import inspect
+import queue
 import random
 import socket
 import string
 import time
 from collections import OrderedDict
 from contextlib import closing
-import queue
-
 from os.path import abspath, dirname
 
 project_dir = dirname(abspath(__file__))
@@ -14,6 +15,18 @@ project_dir = dirname(abspath(__file__))
 STATIC_PATH = '%s/html' % project_dir
 
 
+def iscoroutinefunction(object):
+    while isinstance(object, functools.partial):
+        object = object.func
+    return asyncio.iscoroutinefunction(object)
+
+
+def isgeneratorfunction(object):
+    while isinstance(object, functools.partial):
+        object = object.func
+    return inspect.isgeneratorfunction(object)
+
+
 class LimitedSizeQueue(queue.Queue):
     """
     有限大小的队列