Quellcode durchsuchen

Merge branch 'partials' into main

Falko Schindler vor 3 Jahren
Ursprung
Commit
bff7c1fffc
5 geänderte Dateien mit 17 neuen und 5 gelöschten Zeilen
  1. 2 1
      nicegui/elements/page.py
  2. 2 2
      nicegui/events.py
  3. 9 0
      nicegui/helpers.py
  4. 2 1
      nicegui/routes.py
  5. 2 1
      nicegui/timer.py

+ 2 - 1
nicegui/elements/page.py

@@ -6,6 +6,7 @@ from pygments.formatters import HtmlFormatter
 from starlette.requests import Request
 
 from ..globals import config, page_stack, view_stack
+from ..helpers import is_coroutine
 
 
 class Page(jp.QuasarPage):
@@ -50,7 +51,7 @@ class Page(jp.QuasarPage):
     async def _route_function(self, request: Request):
         if self.on_connect:
             arg_count = len(inspect.signature(self.on_connect).parameters)
-            is_coro = inspect.iscoroutinefunction(self.on_connect)
+            is_coro = is_coroutine(self.on_connect)
             if arg_count == 1:
                 await self.on_connect(request) if is_coro else self.on_connect(request)
             elif arg_count == 0:

+ 2 - 2
nicegui/events.py

@@ -1,4 +1,3 @@
-import asyncio
 import traceback
 from inspect import signature
 from typing import Any, Callable, List, Optional
@@ -8,6 +7,7 @@ from pydantic import BaseModel
 from starlette.websockets import WebSocket
 
 from .elements.element import Element
+from .helpers import is_coroutine
 from .task_logger import create_task
 
 
@@ -222,7 +222,7 @@ def handle_event(handler: Optional[Callable], arguments: EventArguments, *,
             return False
         no_arguments = not signature(handler).parameters
         result = handler() if no_arguments else handler(arguments)
-        if asyncio.iscoroutinefunction(handler):
+        if is_coroutine(handler):
             async def async_handler():
                 try:
                     await result

+ 9 - 0
nicegui/helpers.py

@@ -1,5 +1,8 @@
+import asyncio
+import functools
 import inspect
 import time
+from typing import Any
 
 
 def measure(*, reset: bool = False, ms: bool = False):
@@ -12,3 +15,9 @@ def measure(*, reset: bool = False, ms: bool = False):
     if reset:
         print('------------', flush=True)
     t = time.time()
+
+
+def is_coroutine(object: Any) -> bool:
+    while isinstance(object, functools.partial):
+        object = object.func
+    return asyncio.iscoroutinefunction(object)

+ 2 - 1
nicegui/routes.py

@@ -4,6 +4,7 @@ from functools import wraps
 from starlette import requests, routing
 
 from . import globals
+from .helpers import is_coroutine
 
 
 def add_route(self, route):
@@ -38,7 +39,7 @@ def get(self, path: str):
                     args[key] = complex(args[key])
             if 'request' in parameters and 'request' not in args:
                 args['request'] = request
-            return await func(**args) if inspect.iscoroutinefunction(func) else func(**args)
+            return await func(**args) if is_coroutine(func) else func(**args)
         self.add_route(routing.Route(path, decorated))
         return decorated
     return decorator

+ 2 - 1
nicegui/timer.py

@@ -6,6 +6,7 @@ from typing import Callable, List
 
 from .binding import BindableProperty
 from .globals import tasks, view_stack
+from .helpers import is_coroutine
 from .task_logger import create_task
 
 NamedCoroutine = namedtuple('NamedCoroutine', ['name', 'coro'])
@@ -34,7 +35,7 @@ class Timer:
 
         async def do_callback():
             try:
-                if asyncio.iscoroutinefunction(callback):
+                if is_coroutine(callback):
                     return await callback()
                 else:
                     return callback()