123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- """Mixin that allows tasks to run during the whole app lifespan."""
- from __future__ import annotations
- import asyncio
- import contextlib
- import functools
- import inspect
- import sys
- from typing import Callable, Coroutine, Set, Union
- from fastapi import FastAPI
- from .mixin import AppMixin
class LifespanMixin(AppMixin):
    """A mixin that allows tasks to run during the whole app lifespan.

    Tasks are collected with ``register_lifespan_task`` and are started by
    ``_run_lifespan_tasks``, which cancels the background tasks it tracks
    when the lifespan context exits.
    """

    # Lifespan tasks that are planned to run.
    # NOTE(review): this default is a class-level set; whether instances end up
    # sharing it depends on how AppMixin treats field defaults -- confirm.
    lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()

    @contextlib.asynccontextmanager
    async def _run_lifespan_tasks(self, app: FastAPI):
        """Start every registered lifespan task and stop them on exit.

        Each registered entry is handled according to its kind:

        * an ``asyncio.Task`` is tracked as-is so it gets cancelled on exit;
        * a callable is invoked (with ``app=app`` when its signature has a
          parameter named ``app``); if the result is an async context manager
          it is entered for the duration of the lifespan, and if it is a
          coroutine it is scheduled as a background task. Any other result is
          ignored.

        Args:
            app: The FastAPI app, forwarded to callables that accept ``app``.

        Yields:
            None, once all tasks have been started and contexts entered.
        """
        running_tasks = []
        try:
            async with contextlib.AsyncExitStack() as stack:
                for task in self.lifespan_tasks:
                    if isinstance(task, asyncio.Task):
                        running_tasks.append(task)
                        continue

                    if "app" in inspect.signature(task).parameters:
                        task = functools.partial(task, app=app)
                    result = task()

                    # Use the public ABC rather than the private
                    # contextlib._AsyncGeneratorContextManager so that any
                    # async context manager (generator- or class-based) works.
                    if isinstance(result, contextlib.AbstractAsyncContextManager):
                        await stack.enter_async_context(result)
                    elif isinstance(result, Coroutine):
                        running_tasks.append(asyncio.create_task(result))
                yield
        finally:
            # Task.cancel() only accepts the ``msg`` argument on Python 3.9+.
            cancel_kwargs = (
                {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
            )
            for running in running_tasks:
                running.cancel(**cancel_kwargs)

    def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
        """Register a task to run during the lifespan of the app.

        Args:
            task: The task to register.
            task_kwargs: The kwargs of the task.

        Raises:
            TypeError: If kwargs are given for a task that is not callable
                (for example an already created ``asyncio.Task``).
        """
        if task_kwargs:
            if not callable(task):
                # functools.partial would raise TypeError here as well; fail
                # with a message that names the actual problem instead.
                raise TypeError(
                    "task_kwargs can only be bound to a callable task, "
                    f"got {type(task).__name__}"
                )
            task = functools.partial(task, **task_kwargs)  # type: ignore
        self.lifespan_tasks.add(task)  # type: ignore
|