|
@@ -1,5 +1,6 @@
|
|
import asyncio
|
|
import asyncio
|
|
import functools
|
|
import functools
|
|
|
|
+import inspect
|
|
from contextlib import nullcontext
|
|
from contextlib import nullcontext
|
|
from typing import Any, Awaitable, Callable, List, Optional, Union
|
|
from typing import Any, Awaitable, Callable, List, Optional, Union
|
|
|
|
|
|
@@ -14,7 +15,7 @@ def is_coroutine(object: Any) -> bool:
|
|
return asyncio.iscoroutinefunction(object)
|
|
return asyncio.iscoroutinefunction(object)
|
|
|
|
|
|
|
|
|
|
-def safe_invoke(func: Union[Callable, Awaitable], client: Optional[Client] = None, *args: List[Any]) -> None:
|
|
|
|
|
|
+def safe_invoke(func: Union[Callable, Awaitable], client: Optional[Client] = None) -> None:
|
|
try:
|
|
try:
|
|
if isinstance(func, Awaitable):
|
|
if isinstance(func, Awaitable):
|
|
async def func_with_client():
|
|
async def func_with_client():
|
|
@@ -23,7 +24,7 @@ def safe_invoke(func: Union[Callable, Awaitable], client: Optional[Client] = Non
|
|
create_task(func_with_client())
|
|
create_task(func_with_client())
|
|
else:
|
|
else:
|
|
with client or nullcontext():
|
|
with client or nullcontext():
|
|
- result = func(*args)
|
|
|
|
|
|
+ result = func(client) if len(inspect.signature(func).parameters) == 1 and client is not None else func()
|
|
if isinstance(result, Awaitable):
|
|
if isinstance(result, Awaitable):
|
|
async def result_with_client():
|
|
async def result_with_client():
|
|
with client or nullcontext():
|
|
with client or nullcontext():
|