Browse Source

pass lambda calling async func in event_handlers

Rodja Trappe 1 year ago
parent
commit
78d7c11e70
3 changed files with 26 additions and 12 deletions
  1. 17 11
      nicegui/events.py
  2. 1 1
      nicegui/helpers.py
  3. 8 0
      tests/test_events.py

+ 17 - 11
nicegui/events.py

@@ -1,6 +1,8 @@
 from dataclasses import dataclass
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Awaitable, BinaryIO, Callable, Dict, List, Optional, Union
+
+from nicegui.slot import Slot
 
 from . import background_tasks, globals
 from .helpers import KWONLY_SLOTS, is_coroutine
@@ -268,24 +270,28 @@ class KeyEventArguments(EventArguments):
     modifiers: KeyboardModifiers
 
 
+def run_coroutine(result: Awaitable, name: str, slot: Slot):
+    async def wait_for_result():
+        with slot:
+            await result
+    if globals.loop and globals.loop.is_running():
+        background_tasks.create(wait_for_result(), name=name)
+    else:
+        globals.app.on_startup(wait_for_result())
+
+
 def handle_event(handler: Optional[Callable[..., Any]],
                  arguments: Union[EventArguments, Dict], *,
                  sender: Optional['Element'] = None) -> None:
+    if handler is None:
+        return
     try:
-        if handler is None:
-            return
         no_arguments = not any(p.default is Parameter.empty for p in signature(handler).parameters.values())
         sender = arguments.sender if isinstance(arguments, EventArguments) else sender
         assert sender is not None and sender.parent_slot is not None
         with sender.parent_slot:
             result = handler() if no_arguments else handler(arguments)
-        if is_coroutine(handler):
-            async def wait_for_result():
-                with sender.parent_slot:
-                    await result
-            if globals.loop and globals.loop.is_running():
-                background_tasks.create(wait_for_result(), name=str(handler))
-            else:
-                globals.app.on_startup(wait_for_result())
+        if is_coroutine(handler) or is_coroutine(result):
+            run_coroutine(result, str(handler), sender.parent_slot)
     except Exception as e:
         globals.handle_exception(e)

+ 1 - 1
nicegui/helpers.py

@@ -20,7 +20,7 @@ KWONLY_SLOTS = {'kw_only': True, 'slots': True} if sys.version_info >= (3, 10) e
 def is_coroutine(object: Any) -> bool:
     while isinstance(object, functools.partial):
         object = object.func
-    return asyncio.iscoroutinefunction(object)
+    return asyncio.iscoroutinefunction(object) or asyncio.iscoroutine(object)
 
 
 def safe_invoke(func: Union[Callable[..., Any], Awaitable], client: Optional['Client'] = None) -> None:

+ 8 - 0
tests/test_events.py

@@ -26,21 +26,29 @@ async def click_async_with_args(_: ClickEventArguments):
     ui.label('click_async_with_args')
 
 
+async def click_lambda_with_async_and_parameters(msg: str):
+    await asyncio.sleep(0.1)
+    ui.label(f'click_lambda_with_async_and_parameters: {msg}')
+
+
 def test_click_events(screen: Screen):
     ui.button('click_sync_no_args', on_click=click_sync_no_args)
     ui.button('click_sync_with_args', on_click=click_sync_with_args)
     ui.button('click_async_no_args', on_click=click_async_no_args)
     ui.button('click_async_with_args', on_click=click_async_with_args)
+    ui.button('click_lambda_with_async_and_parameters', on_click=lambda: click_lambda_with_async_and_parameters('works'))
 
     screen.open('/')
     screen.click('click_sync_no_args')
     screen.click('click_sync_with_args')
     screen.click('click_async_no_args')
     screen.click('click_async_with_args')
+    screen.click('click_lambda_with_async_and_parameters')
     screen.should_contain('click_sync_no_args')
     screen.should_contain('click_sync_with_args')
     screen.should_contain('click_async_no_args')
     screen.should_contain('click_async_with_args')
+    screen.should_contain('click_lambda_with_async_and_parameters: works')
 
 
 def test_generic_events(screen: Screen):