Prechádzať zdrojové kódy

cleanup background_tasks and introduce create_lazy

Falko Schindler 2 rokov pred
rodič
commit
843ecac4f9
1 zmenil súbory, kde vykonal 35 pridanie a 40 odobranie
  1. 35 40
      nicegui/background_tasks.py

+ 35 - 40
nicegui/background_tasks.py

@@ -1,63 +1,58 @@
-'''original copied from https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/'''
-
+'''inspired from https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/'''
 import asyncio
-import functools
 import logging
 import sys
-from typing import Any, Awaitable, Optional, Tuple, TypeVar
+from typing import Awaitable, Dict, Set, TypeVar
 
 from . import globals
 
 T = TypeVar('T')
 
+name_supported = sys.version_info[1] >= 8
+
 logger = logging.getLogger(__name__)
 
-running_tasks = set()
+running_tasks: Set[asyncio.Task] = set()
+lazy_tasks_running: Dict[str, asyncio.Task] = {}
+lazy_tasks_waiting: Dict[str, Awaitable[T]] = {}
 
 
-def create(
-    coroutine: Awaitable[T],
-    *,
-    loop: Optional[asyncio.AbstractEventLoop] = None,
-    name: str = 'unnamed task',
-) -> 'asyncio.Task[T]':  # This type annotation has to be quoted for Python < 3.9, see https://www.python.org/dev/peps/pep-0585/
-    '''
-    This helper function wraps a ``loop.create_task(coroutine())`` call and ensures there is
-    an exception handler added to the resulting task. If the task raises an exception it is logged
-    using the provided ``logger``, with additional context provided by ``message`` and optionally
-    ``message_args``.
+def create(coroutine: Awaitable[T], *, name: str = 'unnamed task') -> 'asyncio.Task[T]':
+    '''Wraps a loop.create_task call and ensures there is an exception handler added to the task.
+
+    If the task raises an exception it is logged using a ``logger``.
     Also a reference to the task is kept until it is done, so that the task is not garbage collected mid-execution.
     See https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task.
     '''
-    message = 'Task raised an exception'
-    message_args = ()
-    if loop is None:
-        loop = globals.loop
-        assert loop is not None
-    if sys.version_info[1] < 8:
-        task: asyncio.Task[T] = loop.create_task(coroutine)  # name parameter is only supported from 3.8 onward
-    else:
-        task: asyncio.Task[T] = loop.create_task(coroutine, name=name)
-    task.add_done_callback(
-        functools.partial(_handle_task_result, logger=logger, message=message, message_args=message_args)
-    )
+    task = globals.loop.create_task(coroutine, name=name) if name_supported else globals.loop.create_task(coroutine)
+    task.add_done_callback(_handle_task_result)
     running_tasks.add(task)
     task.add_done_callback(running_tasks.discard)
     return task
 
 
-def _handle_task_result(
-    task: asyncio.Task,
-    *,
-    logger: logging.Logger,
-    message: str,
-    message_args: Tuple[Any, ...] = (),
-) -> None:
+def create_lazy(coroutine: Awaitable[T], *, name: str) -> 'asyncio.Task[T]':
+    '''Wraps a create call and ensures a second task with the same name is delayed until the first one is done.
+
+    If a third task with the same name is created while the first one is still running, the second one is discarded.
+    '''
+    if name in lazy_tasks_running:
+        lazy_tasks_waiting[name] = coroutine
+        return
+
+    def finalize(name: str) -> None:
+        lazy_tasks_running.pop(name)
+        if name in lazy_tasks_waiting:
+            create_lazy(lazy_tasks_waiting.pop(name), name=name)
+    task = create(coroutine, name=name)
+    lazy_tasks_running[name] = task
+    task.add_done_callback(lambda _: finalize(name))
+
+
+def _handle_task_result(task: asyncio.Task) -> None:
     try:
         task.result()
     except asyncio.CancelledError:
-        pass  # Task cancellation should not be logged as an error.
-    # Ad the pylint ignore: we want to handle all exceptions here so that the result of the task
-    # is properly logged. There is no point re-raising the exception in this callback.
-    except Exception:  # pylint: disable=broad-except
-        logger.exception(message, *message_args)
+        pass
+    except Exception:
+        logger.exception('Task raised an exception')