lifespan.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. """Mixin that allow tasks to run during the whole app lifespan."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import functools
  6. import inspect
  7. import sys
  8. from typing import Callable, Coroutine, Set, Union
  9. from fastapi import FastAPI
  10. from .mixin import AppMixin
  class LifespanMixin(AppMixin):
      """A Mixin that allows tasks to run during the whole app lifespan.

      Registered entries are started when the app lifespan begins and torn down
      when it ends. Supported entries:

      * an already-created ``asyncio.Task``: cancelled and awaited on shutdown;
      * a callable returning a coroutine: scheduled with
        ``asyncio.create_task`` on startup, cancelled and awaited on shutdown;
      * a callable returning an async context manager (for example a function
        decorated with ``contextlib.asynccontextmanager``): entered on startup
        and exited on shutdown.

      A callable whose signature declares an ``app`` parameter is called with
      the running app as ``app=...``. Any other return value is ignored.
      """

      # Lifespan tasks that are planned to run.
      # NOTE(review): this default is a class-level set. Unless the base class
      # (AppMixin, not visible here) copies field defaults per instance, every
      # instance shares -- and registers into -- the same set. Iteration order
      # over a set is unspecified, so tasks start in no guaranteed order.
      lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()

      @contextlib.asynccontextmanager
      async def _run_lifespan_tasks(self, app: FastAPI):
          """Run every registered lifespan task for the duration of the context.

          Args:
              app: The app whose lifespan is running. It is passed as ``app=``
                  to task callables that declare an ``app`` parameter.

          Yields:
              Control back to the caller once every task has been started.
          """
          running_tasks = []
          try:
              async with contextlib.AsyncExitStack() as stack:
                  # Iterate a snapshot: a task started here may itself call
                  # register_lifespan_task, which would otherwise raise
                  # "Set changed size during iteration".
                  for task in list(self.lifespan_tasks):
                      if isinstance(task, asyncio.Task):
                          running_tasks.append(task)
                          continue
                      if "app" in inspect.signature(task).parameters:
                          task = functools.partial(task, app=app)
                      result = task()
                      # Accept any async context manager, not only the private
                      # contextlib._AsyncGeneratorContextManager type.
                      if isinstance(result, contextlib.AbstractAsyncContextManager):
                          await stack.enter_async_context(result)
                      elif isinstance(result, Coroutine):
                          running_tasks.append(asyncio.create_task(result))
                  yield
          finally:
              # Task.cancel() only accepts ``msg`` from Python 3.9 onwards.
              cancel_kwargs = (
                  {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
              )
              for task in running_tasks:
                  task.cancel(**cancel_kwargs)
              # Wait for the cancelled tasks to finish so their cleanup code runs
              # and their outcome (CancelledError or an exception) is retrieved
              # here instead of being reported when the task is garbage collected.
              if running_tasks:
                  await asyncio.gather(*running_tasks, return_exceptions=True)

      def register_lifespan_task(
          self, task: Callable | asyncio.Task, **task_kwargs
      ) -> None:
          """Register a task to run during the lifespan of the app.

          Args:
              task: The task to register: an ``asyncio.Task``, or a callable
                  returning a coroutine or an async context manager.
              task_kwargs: Keyword arguments bound to the callable with
                  ``functools.partial`` before it is registered.

          Raises:
              TypeError: If ``task_kwargs`` are given together with an
                  ``asyncio.Task``, which is not callable and so cannot take
                  arguments.
          """
          if task_kwargs:
              if isinstance(task, asyncio.Task):
                  raise TypeError(
                      "task_kwargs cannot be bound to an already-created asyncio.Task"
                  )
              task = functools.partial(task, **task_kwargs)  # type: ignore
          self.lifespan_tasks.add(task)  # type: ignore