Explorar el Código

Merge pull request #1969 from zauberzeug/refreshable-types

Improve type hints for ui.refreshable
Rodja Trappe hace 1 año
padre
commit
9ff8427d3b
Se han modificado 1 ficheros con 15 adiciones y 10 borrados
  1. 15 10
      nicegui/functions/refreshable.py

+ 15 - 10
nicegui/functions/refreshable.py

@@ -1,9 +1,9 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 from dataclasses import dataclass, field
 from dataclasses import dataclass, field
-from typing import Any, Awaitable, Callable, ClassVar, Dict, List, Optional, Tuple, Union, cast
+from typing import Any, Awaitable, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast
 
 
-from typing_extensions import Self
+from typing_extensions import ParamSpec, Self
 
 
 from .. import background_tasks, core
 from .. import background_tasks, core
 from ..client import Client
 from ..client import Client
@@ -11,6 +11,9 @@ from ..dataclasses import KWONLY_SLOTS
 from ..element import Element
 from ..element import Element
 from ..helpers import is_coroutine_function
 from ..helpers import is_coroutine_function
 
 
+_T = TypeVar('_T')
+_P = ParamSpec('_P')
+
 
 
 @dataclass(**KWONLY_SLOTS)
 @dataclass(**KWONLY_SLOTS)
 class RefreshableTarget:
 class RefreshableTarget:
@@ -24,7 +27,7 @@ class RefreshableTarget:
     locals: List[Any] = field(default_factory=list)
     locals: List[Any] = field(default_factory=list)
     next_index: int = 0
     next_index: int = 0
 
 
-    def run(self, func: Callable[..., Any]) -> Union[Any, Awaitable]:
+    def run(self, func: Callable[..., Union[_T, Awaitable[_T]]]) -> Union[_T, Awaitable[_T]]:
         """Run the function and return the result."""
         """Run the function and return the result."""
         RefreshableTarget.current_target = self
         RefreshableTarget.current_target = self
         self.next_index = 0
         self.next_index = 0
@@ -33,9 +36,11 @@ class RefreshableTarget:
             async def wait_for_result() -> Any:
             async def wait_for_result() -> Any:
                 with self.container:
                 with self.container:
                     if self.instance is None:
                     if self.instance is None:
-                        return await func(*self.args, **self.kwargs)
+                        result = func(*self.args, **self.kwargs)
                     else:
                     else:
-                        return await func(self.instance, *self.args, **self.kwargs)
+                        result = func(self.instance, *self.args, **self.kwargs)
+                    assert isinstance(result, Awaitable)
+                    return await result
             return wait_for_result()
             return wait_for_result()
         else:
         else:
             with self.container:
             with self.container:
@@ -49,9 +54,9 @@ class RefreshableContainer(Element, component='refreshable.js'):
     pass
     pass
 
 
 
 
-class refreshable:
+class refreshable(Generic[_P, _T]):
 
 
-    def __init__(self, func: Callable[..., Any]) -> None:
+    def __init__(self, func: Callable[_P, Union[_T, Awaitable[_T]]]) -> None:
         """Refreshable UI functions
         """Refreshable UI functions
 
 
         The `@ui.refreshable` decorator allows you to create functions that have a `refresh` method.
         The `@ui.refreshable` decorator allows you to create functions that have a `refresh` method.
@@ -74,14 +79,14 @@ class refreshable:
             return refresh
             return refresh
         return attribute
         return attribute
 
 
-    def __call__(self, *args: Any, **kwargs: Any) -> Union[Any, Awaitable]:
+    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Union[_T, Awaitable[_T]]:
         self.prune()
         self.prune()
         target = RefreshableTarget(container=RefreshableContainer(), refreshable=self, instance=self.instance,
         target = RefreshableTarget(container=RefreshableContainer(), refreshable=self, instance=self.instance,
                                    args=args, kwargs=kwargs)
                                    args=args, kwargs=kwargs)
         self.targets.append(target)
         self.targets.append(target)
         return target.run(self.func)
         return target.run(self.func)
 
 
-    def refresh(self, *args: Any, **kwargs: Any) -> None:
+    def refresh(self, *args: _P.args, **kwargs: _P.kwargs) -> None:
         """Refresh the UI elements created by this function."""
         """Refresh the UI elements created by this function."""
         self.prune()
         self.prune()
         for target in self.targets:
         for target in self.targets:
@@ -100,7 +105,7 @@ class refreshable:
                                     'either as positional or as keyword argument') from e
                                     'either as positional or as keyword argument') from e
                 raise
                 raise
             if is_coroutine_function(self.func):
             if is_coroutine_function(self.func):
-                assert result is not None
+                assert isinstance(result, Awaitable)
                 if core.loop and core.loop.is_running():
                 if core.loop and core.loop.is_running():
                     background_tasks.create(result)
                     background_tasks.create(result)
                 else:
                 else: