Procházet zdrojové kódy

use helpers.is_coroutine_function to avoid awaiting AwaitableResponses by accident

Falko Schindler před 1 rokem
rodič
revize
edc5182c85

+ 3 - 3
examples/single_page_app/router.py

@@ -1,6 +1,6 @@
-from typing import Awaitable, Callable, Dict, Union
+from typing import Callable, Dict, Union
 
-from nicegui import background_tasks, ui
+from nicegui import background_tasks, helpers, ui
 
 
 class RouterFrame(ui.element, component='router_frame.js'):
@@ -35,7 +35,7 @@ class Router():
                     }}
                 ''')
                 result = builder()
-                if isinstance(result, Awaitable):
+                if helpers.is_coroutine_function(builder):
                     await result
         self.content.clear()
         background_tasks.create(build())

+ 3 - 3
nicegui/events.py

@@ -3,9 +3,9 @@ from __future__ import annotations
 from contextlib import nullcontext
 from dataclasses import dataclass
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Awaitable, BinaryIO, Callable, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, List, Literal, Optional, Union
 
-from . import background_tasks, globals  # pylint: disable=redefined-builtin
+from . import background_tasks, globals, helpers  # pylint: disable=redefined-builtin
 from .dataclasses import KWONLY_SLOTS
 from .slot import Slot
 
@@ -432,7 +432,7 @@ def handle_event(handler: Optional[Callable[..., Any]], arguments: EventArgument
 
         with parent_slot:
             result = handler(arguments) if expects_arguments else handler()
-        if isinstance(result, Awaitable):
+        if helpers.is_coroutine_function(handler):
             async def wait_for_result():
                 with parent_slot:
                     try:

+ 3 - 3
nicegui/functions/timer.py

@@ -1,8 +1,8 @@
 import asyncio
 import time
-from typing import Any, Awaitable, Callable, Optional
+from typing import Any, Callable, Optional
 
-from .. import background_tasks, globals  # pylint: disable=redefined-builtin
+from .. import background_tasks, globals, helpers  # pylint: disable=redefined-builtin
 from ..binding import BindableProperty
 from ..slot import Slot
 
@@ -90,7 +90,7 @@ class Timer:
         try:
             assert self.callback is not None
             result = self.callback()
-            if isinstance(result, Awaitable):
+            if helpers.is_coroutine_function(self.callback):
                 await result
         except Exception as e:
             globals.handle_exception(e)

+ 2 - 1
nicegui/globals.py

@@ -13,6 +13,7 @@ from socketio import AsyncServer
 from uvicorn import Server
 
 from . import background_tasks
+from .helpers import is_coroutine_function
 
 if TYPE_CHECKING:
     from .air import Air
@@ -123,5 +124,5 @@ def handle_exception(exception: Exception) -> None:
     """Handle an exception by invoking all registered exception handlers."""
     for handler in exception_handlers:
         result = handler() if not inspect.signature(handler).parameters else handler(exception)
-        if isinstance(result, Awaitable):
+        if is_coroutine_function(handler):
             background_tasks.create(result)

+ 1 - 1
nicegui/helpers.py

@@ -72,7 +72,7 @@ def safe_invoke(func: Union[Callable[..., Any], Awaitable], client: Optional[Cli
         else:
             with client or nullcontext():
                 result = func(client) if len(inspect.signature(func).parameters) == 1 and client is not None else func()
-            if isinstance(result, Awaitable):
+            if is_coroutine_function(func):
                 async def result_with_client():
                     with client or nullcontext():
                         await result

+ 2 - 2
nicegui/page.py

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 from fastapi import Request, Response
 
-from . import background_tasks, binding, globals  # pylint: disable=redefined-builtin
+from . import background_tasks, binding, globals, helpers  # pylint: disable=redefined-builtin
 from .client import Client
 from .favicon import create_favicon_route
 from .language import Language
@@ -95,7 +95,7 @@ class page:
                 if any(p.name == 'client' for p in inspect.signature(func).parameters.values()):
                     dec_kwargs['client'] = client
                 result = func(*dec_args, **dec_kwargs)
-            if inspect.isawaitable(result):
+            if helpers.is_coroutine_function(func):
                 async def wait_for_result() -> None:
                     with client:
                         return await result