lifespan.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. """Mixin that allow tasks to run during the whole app lifespan."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import functools
  6. import inspect
  7. from typing import Callable, Coroutine, Set, Union
  8. from fastapi import FastAPI
  9. from reflex.utils import console
  10. from reflex.utils.exceptions import InvalidLifespanTaskType
  11. from .mixin import AppMixin
  12. class LifespanMixin(AppMixin):
  13. """A Mixin that allow tasks to run during the whole app lifespan."""
  14. # Lifespan tasks that are planned to run.
  15. lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
  16. @contextlib.asynccontextmanager
  17. async def _run_lifespan_tasks(self, app: FastAPI):
  18. running_tasks = []
  19. try:
  20. async with contextlib.AsyncExitStack() as stack:
  21. for task in self.lifespan_tasks:
  22. run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # type: ignore
  23. if isinstance(task, asyncio.Task):
  24. running_tasks.append(task)
  25. else:
  26. signature = inspect.signature(task)
  27. if "app" in signature.parameters:
  28. task = functools.partial(task, app=app)
  29. _t = task()
  30. if isinstance(_t, contextlib._AsyncGeneratorContextManager):
  31. await stack.enter_async_context(_t)
  32. console.debug(run_msg.format(type="asynccontextmanager"))
  33. elif isinstance(_t, Coroutine):
  34. task_ = asyncio.create_task(_t)
  35. task_.add_done_callback(lambda t: t.result())
  36. running_tasks.append(task_)
  37. console.debug(run_msg.format(type="coroutine"))
  38. else:
  39. console.debug(run_msg.format(type="function"))
  40. yield
  41. finally:
  42. for task in running_tasks:
  43. console.debug(f"Canceling lifespan task: {task}")
  44. task.cancel(msg="lifespan_cleanup")
  45. def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
  46. """Register a task to run during the lifespan of the app.
  47. Args:
  48. task: The task to register.
  49. task_kwargs: The kwargs of the task.
  50. Raises:
  51. InvalidLifespanTaskType: If the task is a generator function.
  52. """
  53. if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
  54. raise InvalidLifespanTaskType(
  55. f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
  56. )
  57. if task_kwargs:
  58. original_task = task
  59. task = functools.partial(task, **task_kwargs) # type: ignore
  60. functools.update_wrapper(task, original_task) # type: ignore
  61. self.lifespan_tasks.add(task) # type: ignore
  62. console.debug(f"Registered lifespan task: {task.__name__}") # type: ignore