1
0

lifespan.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. """Mixin that allows tasks to run during the whole app lifespan."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import dataclasses
  6. import functools
  7. import inspect
  8. from collections.abc import Callable, Coroutine
  9. from starlette.applications import Starlette
  10. from reflex.utils import console
  11. from reflex.utils.exceptions import InvalidLifespanTaskTypeError
  12. from .mixin import AppMixin
  13. @dataclasses.dataclass
  14. class LifespanMixin(AppMixin):
  15. """A Mixin that allow tasks to run during the whole app lifespan."""
  16. # Lifespan tasks that are planned to run.
  17. lifespan_tasks: set[asyncio.Task | Callable] = dataclasses.field(
  18. default_factory=set
  19. )
  20. @contextlib.asynccontextmanager
  21. async def _run_lifespan_tasks(self, app: Starlette):
  22. running_tasks = []
  23. try:
  24. async with contextlib.AsyncExitStack() as stack:
  25. for task in self.lifespan_tasks:
  26. run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # pyright: ignore [reportAttributeAccessIssue]
  27. if isinstance(task, asyncio.Task):
  28. running_tasks.append(task)
  29. else:
  30. signature = inspect.signature(task)
  31. if "app" in signature.parameters:
  32. task = functools.partial(task, app=app)
  33. _t = task()
  34. if isinstance(_t, contextlib._AsyncGeneratorContextManager):
  35. await stack.enter_async_context(_t)
  36. console.debug(run_msg.format(type="asynccontextmanager"))
  37. elif isinstance(_t, Coroutine):
  38. task_ = asyncio.create_task(_t)
  39. task_.add_done_callback(lambda t: t.result())
  40. running_tasks.append(task_)
  41. console.debug(run_msg.format(type="coroutine"))
  42. else:
  43. console.debug(run_msg.format(type="function"))
  44. yield
  45. finally:
  46. for task in running_tasks:
  47. console.debug(f"Canceling lifespan task: {task}")
  48. task.cancel(msg="lifespan_cleanup")
  49. def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
  50. """Register a task to run during the lifespan of the app.
  51. Args:
  52. task: The task to register.
  53. **task_kwargs: The kwargs of the task.
  54. Raises:
  55. InvalidLifespanTaskTypeError: If the task is a generator function.
  56. """
  57. if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
  58. raise InvalidLifespanTaskTypeError(
  59. f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
  60. )
  61. if task_kwargs:
  62. original_task = task
  63. task = functools.partial(task, **task_kwargs) # pyright: ignore [reportArgumentType]
  64. functools.update_wrapper(task, original_task) # pyright: ignore [reportArgumentType]
  65. self.lifespan_tasks.add(task)
  66. console.debug(f"Registered lifespan task: {task.__name__}") # pyright: ignore [reportAttributeAccessIssue]