Ver Fonte

Merge pull request #305 from zauberzeug/generic_events

Generic events
Rodja Trappe há 2 anos atrás
pai
commit
382ee8bae4
5 ficheiros alterados com 74 adições e 14 exclusões
  1. 3 4
      nicegui/element.py
  2. 10 7
      nicegui/events.py
  3. 5 3
      nicegui/helpers.py
  4. 1 0
      nicegui/ui.py
  5. 55 0
      tests/test_events.py

+ 3 - 4
nicegui/element.py

@@ -3,11 +3,12 @@ from __future__ import annotations
 import shlex
 import shlex
 from abc import ABC
 from abc import ABC
 from copy import deepcopy
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 
 from . import background_tasks, binding, globals
 from . import background_tasks, binding, globals
 from .elements.mixins.visibility import Visibility
 from .elements.mixins.visibility import Visibility
 from .event_listener import EventListener
 from .event_listener import EventListener
+from .events import handle_event
 from .slot import Slot
 from .slot import Slot
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -158,9 +159,7 @@ class Element(ABC, Visibility):
     def handle_event(self, msg: Dict) -> None:
     def handle_event(self, msg: Dict) -> None:
         for listener in self._event_listeners:
         for listener in self._event_listeners:
             if listener.type == msg['type']:
             if listener.type == msg['type']:
-                result = listener.handler(msg)
-                if isinstance(result, Awaitable):
-                    background_tasks.create(result)
+                handle_event(listener.handler, msg, sender=self)
 
 
     def collect_descendant_ids(self) -> List[int]:
     def collect_descendant_ids(self) -> List[int]:
         '''includes own ID as first element'''
         '''includes own ID as first element'''

+ 10 - 7
nicegui/events.py

@@ -1,21 +1,21 @@
 import traceback
 import traceback
 from dataclasses import dataclass
 from dataclasses import dataclass
 from inspect import signature
 from inspect import signature
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, List, Optional, Union
 
 
 from . import background_tasks, globals
 from . import background_tasks, globals
 from .async_updater import AsyncUpdater
 from .async_updater import AsyncUpdater
-from .client import Client
 from .helpers import is_coroutine
 from .helpers import is_coroutine
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
+    from .client import Client
     from .element import Element
     from .element import Element
 
 
 
 
 @dataclass
 @dataclass
 class EventArguments:
 class EventArguments:
     sender: 'Element'
     sender: 'Element'
-    client: Client
+    client: 'Client'
 
 
 
 
 @dataclass
 @dataclass
@@ -259,17 +259,20 @@ class KeyEventArguments(EventArguments):
     modifiers: KeyboardModifiers
     modifiers: KeyboardModifiers
 
 
 
 
-def handle_event(handler: Optional[Callable], arguments: EventArguments) -> None:
+def handle_event(handler: Optional[Callable],
+                 arguments: Union[EventArguments, dict], *,
+                 sender: Optional['Element'] = None) -> None:
     try:
     try:
         if handler is None:
         if handler is None:
             return
             return
         no_arguments = not signature(handler).parameters
         no_arguments = not signature(handler).parameters
-        assert arguments.sender.parent_slot is not None
-        with arguments.sender.parent_slot:
+        sender = arguments.sender if isinstance(arguments, EventArguments) else sender
+        assert sender.parent_slot is not None
+        with sender.parent_slot:
             result = handler() if no_arguments else handler(arguments)
             result = handler() if no_arguments else handler(arguments)
         if is_coroutine(handler):
         if is_coroutine(handler):
             async def wait_for_result():
             async def wait_for_result():
-                with arguments.sender.parent_slot:
+                with sender.parent_slot:
                     await AsyncUpdater(result)
                     await AsyncUpdater(result)
             if globals.loop and globals.loop.is_running():
             if globals.loop and globals.loop.is_running():
                 background_tasks.create(wait_for_result(), name=str(handler))
                 background_tasks.create(wait_for_result(), name=str(handler))

+ 5 - 3
nicegui/helpers.py

@@ -2,10 +2,12 @@ import asyncio
 import functools
 import functools
 import inspect
 import inspect
 from contextlib import nullcontext
 from contextlib import nullcontext
-from typing import Any, Awaitable, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union
 
 
 from . import background_tasks, globals
 from . import background_tasks, globals
-from .client import Client
+
+if TYPE_CHECKING:
+    from .client import Client
 
 
 
 
 def is_coroutine(object: Any) -> bool:
 def is_coroutine(object: Any) -> bool:
@@ -14,7 +16,7 @@ def is_coroutine(object: Any) -> bool:
     return asyncio.iscoroutinefunction(object)
     return asyncio.iscoroutinefunction(object)
 
 
 
 
-def safe_invoke(func: Union[Callable, Awaitable], client: Optional[Client] = None) -> None:
+def safe_invoke(func: Union[Callable, Awaitable], client: Optional['Client'] = None) -> None:
     try:
     try:
         if isinstance(func, Awaitable):
         if isinstance(func, Awaitable):
             async def func_with_client():
             async def func_with_client():

+ 1 - 0
nicegui/ui.py

@@ -1,5 +1,6 @@
 import os
 import os
 
 
+from .element import Element as element
 from .elements.audio import Audio as audio
 from .elements.audio import Audio as audio
 from .elements.badge import Badge as badge
 from .elements.badge import Badge as badge
 from .elements.button import Button as button
 from .elements.button import Button as button

+ 55 - 0
tests/test_events.py

@@ -3,10 +3,65 @@ import asyncio
 from selenium.webdriver.common.by import By
 from selenium.webdriver.common.by import By
 
 
 from nicegui import ui
 from nicegui import ui
+from nicegui.events import ClickEventArguments
 
 
 from .screen import Screen
 from .screen import Screen
 
 
 
 
+def click_sync_no_args():
+    ui.label('click_sync_no_args')
+
+
+def click_sync_with_args(_: ClickEventArguments):
+    ui.label('click_sync_with_args')
+
+
+async def click_async_no_args():
+    await asyncio.sleep(0.1)
+    ui.label('click_async_no_args')
+
+
+async def click_async_with_args(_: ClickEventArguments):
+    await asyncio.sleep(0.1)
+    ui.label('click_async_with_args')
+
+
+def test_click_events(screen: Screen):
+    ui.button('click_sync_no_args', on_click=click_sync_no_args)
+    ui.button('click_sync_with_args', on_click=click_sync_with_args)
+    ui.button('click_async_no_args', on_click=click_async_no_args)
+    ui.button('click_async_with_args', on_click=click_async_with_args)
+
+    screen.open('/')
+    screen.click('click_sync_no_args')
+    screen.click('click_sync_with_args')
+    screen.click('click_async_no_args')
+    screen.click('click_async_with_args')
+    screen.wait(0.5)
+    screen.should_contain('click_sync_no_args')
+    screen.should_contain('click_sync_with_args')
+    screen.should_contain('click_async_no_args')
+    screen.should_contain('click_async_with_args')
+
+
+def test_generic_events(screen: Screen):
+    ui.label('click_sync_no_args').on('click', click_sync_no_args)
+    ui.label('click_sync_with_args').on('click', click_sync_with_args)
+    ui.label('click_async_no_args').on('click', click_async_no_args)
+    ui.label('click_async_with_args').on('click', click_async_with_args)
+
+    screen.open('/')
+    screen.click('click_sync_no_args')
+    screen.click('click_sync_with_args')
+    screen.click('click_async_no_args')
+    screen.click('click_async_with_args')
+    screen.wait(0.5)
+    screen.should_contain('click_sync_no_args')
+    screen.should_contain('click_sync_with_args')
+    screen.should_contain('click_async_no_args')
+    screen.should_contain('click_async_with_args')
+
+
 def test_event_with_update_before_await(screen: Screen):
 def test_event_with_update_before_await(screen: Screen):
     @ui.page('/')
     @ui.page('/')
     def page():
     def page():