task_logger.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
'''Helpers for creating asyncio tasks whose exceptions are always logged.

Originally copied from https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/
'''
  2. import asyncio
  3. import functools
  4. import logging
  5. import sys
  6. from typing import Any, Awaitable, Optional, Tuple, TypeVar
T = TypeVar('T')  # Result type of the awaitable given to create_task; the returned Task is parameterized by it.
  8. def create_task(
  9. coroutine: Awaitable[T],
  10. *,
  11. loop: Optional[asyncio.AbstractEventLoop] = None,
  12. name: str = 'unnamed task',
  13. ) -> 'asyncio.Task[T]': # This type annotation has to be quoted for Python < 3.9, see https://www.python.org/dev/peps/pep-0585/
  14. '''
  15. This helper function wraps a ``loop.create_task(coroutine())`` call and ensures there is
  16. an exception handler added to the resulting task. If the task raises an exception it is logged
  17. using the provided ``logger``, with additional context provided by ``message`` and optionally
  18. ``message_args``.
  19. '''
  20. logger = logging.getLogger(__name__)
  21. message = 'Task raised an exception'
  22. message_args = ()
  23. if loop is None:
  24. loop = asyncio.get_running_loop()
  25. if sys.version_info[1] < 8:
  26. task = loop.create_task(coroutine) # name parameter is only supported from 3.8 onward
  27. else:
  28. task = loop.create_task(coroutine, name=name)
  29. task.add_done_callback(
  30. functools.partial(_handle_task_result, logger=logger, message=message, message_args=message_args)
  31. )
  32. return task
  33. def _handle_task_result(
  34. task: asyncio.Task,
  35. *,
  36. logger: logging.Logger,
  37. message: str,
  38. message_args: Tuple[Any, ...] = (),
  39. ) -> None:
  40. try:
  41. task.result()
  42. except asyncio.CancelledError:
  43. pass # Task cancellation should not be logged as an error.
  44. # Ad the pylint ignore: we want to handle all exceptions here so that the result of the task
  45. # is properly logged. There is no point re-raising the exception in this callback.
  46. except Exception: # pylint: disable=broad-except
  47. logger.exception(message, *message_args)