Jelajahi Sumber

improve lifespan typecheck and debug (#4014)

* add lifespan debug statement

* improve some of the logic for lifespan tasks

* fix partial name with update_wrapper
Thomas Brandého 8 bulan lalu
induk
melakukan
1b3422dab6
2 mengubah file dengan 28 tambahan dan 6 penghapusan
  1. 24 6
      reflex/app_mixins/lifespan.py
  2. 4 0
      reflex/utils/exceptions.py

+ 24 - 6
reflex/app_mixins/lifespan.py

@@ -6,11 +6,13 @@ import asyncio
 import contextlib
 import functools
 import inspect
-import sys
 from typing import Callable, Coroutine, Set, Union
 
 from fastapi import FastAPI
 
+from reflex.utils import console
+from reflex.utils.exceptions import InvalidLifespanTaskType
+
 from .mixin import AppMixin
 
 
@@ -26,6 +28,7 @@ class LifespanMixin(AppMixin):
         try:
             async with contextlib.AsyncExitStack() as stack:
                 for task in self.lifespan_tasks:
+                    run_msg = f"Started lifespan task: {task.__name__} as {{type}}"  # type: ignore
                     if isinstance(task, asyncio.Task):
                         running_tasks.append(task)
                     else:
@@ -35,15 +38,19 @@ class LifespanMixin(AppMixin):
                         _t = task()
                         if isinstance(_t, contextlib._AsyncGeneratorContextManager):
                             await stack.enter_async_context(_t)
+                            console.debug(run_msg.format(type="asynccontextmanager"))
                         elif isinstance(_t, Coroutine):
-                            running_tasks.append(asyncio.create_task(_t))
+                            task_ = asyncio.create_task(_t)
+                            task_.add_done_callback(lambda t: t.result())
+                            running_tasks.append(task_)
+                            console.debug(run_msg.format(type="coroutine"))
+                        else:
+                            console.debug(run_msg.format(type="function"))
                 yield
         finally:
-            cancel_kwargs = (
-                {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
-            )
             for task in running_tasks:
-                task.cancel(**cancel_kwargs)
+                console.debug(f"Canceling lifespan task: {task}")
+                task.cancel(msg="lifespan_cleanup")
 
     def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
         """Register a task to run during the lifespan of the app.
@@ -51,7 +58,18 @@ class LifespanMixin(AppMixin):
         Args:
             task: The task to register.
             task_kwargs: The kwargs of the task.
+
+        Raises:
+            InvalidLifespanTaskType: If the task is a generator function.
         """
+        if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
+            raise InvalidLifespanTaskType(
+                f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
+            )
+
         if task_kwargs:
+            original_task = task
             task = functools.partial(task, **task_kwargs)  # type: ignore
+            functools.update_wrapper(task, original_task)  # type: ignore
         self.lifespan_tasks.add(task)  # type: ignore
+        console.debug(f"Registered lifespan task: {task.__name__}")  # type: ignore

+ 4 - 0
reflex/utils/exceptions.py

@@ -111,3 +111,7 @@ class GeneratedCodeHasNoFunctionDefs(ReflexError):
 
 class PrimitiveUnserializableToJSON(ReflexError, ValueError):
     """Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity."""
+
+
+class InvalidLifespanTaskType(ReflexError, TypeError):
+    """Raised when an invalid task type is registered as a lifespan task."""