Bläddra i källkod

add support for lifespan tasks (#3312)

* add support for lifespan tasks

* allow passing args to lifespan task

* add message to the cancel call

* allow asynccontextmanager as lifespan tasks

* Fix integration.utils.SessionStorage

Previously the SessionStorage util was just looking in localStorage, but the
tests didn't catch it because they were asserting the token was not None,
rather than asserting it was truthy.

Fixed here, because I'm using this structure in the new lifespan test.

* If the lifespan task or context takes "app" parameter, pass the FastAPI instance.

* test_lifespan: end to end test for register_lifespan_task

* In py3.8, Task.cancel takes no args

* test_lifespan: use polling to make the test more robust

Fix CI failure

* Do not allow task_args for better composability

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Thomas Brandého 1 år sedan
förälder
incheckning
956a526b20
5 ändrade filer med 182 tillägg och 13 borttagningar
  1. 1 2
      integration/test_component_state.py
  2. 120 0
      integration/test_lifespan.py
  3. 1 2
      integration/test_navigation.py
  4. 13 7
      integration/utils.py
  5. 47 2
      reflex/app.py

+ 1 - 2
integration/test_component_state.py

@@ -79,8 +79,7 @@ async def test_component_state_app(component_state_app: AppHarness):
     driver = component_state_app.frontend()
 
     ss = utils.SessionStorage(driver)
-    token = AppHarness._poll_for(lambda: ss.get("token") is not None)
-    assert token is not None
+    assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
 
     count_a = driver.find_element(By.ID, "count-a")
     count_b = driver.find_element(By.ID, "count-b")

+ 120 - 0
integration/test_lifespan.py

@@ -0,0 +1,120 @@
+"""Test cases for the FastAPI lifespan integration."""
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import AppHarness
+
+from .utils import SessionStorage
+
+
+def LifespanApp():
+    """App with lifespan tasks and context."""
+    import asyncio
+    from contextlib import asynccontextmanager
+
+    import reflex as rx
+
+    lifespan_task_global = 0
+    lifespan_context_global = 0
+
+    @asynccontextmanager
+    async def lifespan_context(app, inc: int = 1):
+        global lifespan_context_global
+        print(f"Lifespan context entered: {app}.")
+        lifespan_context_global += inc  # pyright: ignore[reportUnboundVariable]
+        try:
+            yield
+        finally:
+            print("Lifespan context exited.")
+            lifespan_context_global += inc
+
+    async def lifespan_task(inc: int = 1):
+        global lifespan_task_global
+        print("Lifespan global started.")
+        try:
+            while True:
+                lifespan_task_global += inc  # pyright: ignore[reportUnboundVariable]
+                await asyncio.sleep(0.1)
+        except asyncio.CancelledError as ce:
+            print(f"Lifespan global cancelled: {ce}.")
+            lifespan_task_global = 0
+
+    class LifespanState(rx.State):
+        @rx.var
+        def task_global(self) -> int:
+            return lifespan_task_global
+
+        @rx.var
+        def context_global(self) -> int:
+            return lifespan_context_global
+
+        def tick(self, date):
+            pass
+
+    def index():
+        return rx.vstack(
+            rx.text(LifespanState.task_global, id="task_global"),
+            rx.text(LifespanState.context_global, id="context_global"),
+            rx.moment(interval=100, on_change=LifespanState.tick),
+        )
+
+    app = rx.App()
+    app.register_lifespan_task(lifespan_task)
+    app.register_lifespan_task(lifespan_context, inc=2)
+    app.add_page(index)
+
+
+@pytest.fixture()
+def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
+    """Start LifespanApp app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path,
+        app_source=LifespanApp,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.mark.asyncio
+async def test_lifespan(lifespan_app: AppHarness):
+    """Test the lifespan integration.
+
+    Args:
+        lifespan_app: harness for LifespanApp app
+    """
+    assert lifespan_app.app_module is not None, "app module is not found"
+    assert lifespan_app.app_instance is not None, "app is not running"
+    driver = lifespan_app.frontend()
+
+    ss = SessionStorage(driver)
+    assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
+
+    context_global = driver.find_element(By.ID, "context_global")
+    task_global = driver.find_element(By.ID, "task_global")
+
+    assert context_global.text == "2"
+    assert lifespan_app.app_module.lifespan_context_global == 2  # type: ignore
+
+    original_task_global_text = task_global.text
+    original_task_global_value = int(original_task_global_text)
+    lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text)
+    assert lifespan_app.app_module.lifespan_task_global > original_task_global_value  # type: ignore
+    assert int(task_global.text) > original_task_global_value
+
+    # Kill the backend
+    assert lifespan_app.backend is not None
+    lifespan_app.backend.should_exit = True
+    if lifespan_app.backend_thread is not None:
+        lifespan_app.backend_thread.join()
+
+    # Check that the lifespan tasks have been cancelled
+    assert lifespan_app.app_module.lifespan_task_global == 0
+    assert lifespan_app.app_module.lifespan_context_global == 4

+ 1 - 2
integration/test_navigation.py

@@ -67,8 +67,7 @@ async def test_navigation_app(navigation_app: AppHarness):
     driver = navigation_app.frontend()
 
     ss = SessionStorage(driver)
-    token = AppHarness._poll_for(lambda: ss.get("token") is not None)
-    assert token is not None
+    assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
 
     internal_link = driver.find_element(By.ID, "internal")
 

+ 13 - 7
integration/utils.py

@@ -54,7 +54,9 @@ class LocalStorage:
         Returns:
             The number of items in local storage.
         """
-        return int(self.driver.execute_script("return window.localStorage.length;"))
+        return int(
+            self.driver.execute_script(f"return window.{self.storage_key}.length;")
+        )
 
     def items(self) -> dict[str, str]:
         """Get all items in local storage.
@@ -63,7 +65,7 @@ class LocalStorage:
             A dict mapping keys to values.
         """
         return self.driver.execute_script(
-            "var ls = window.localStorage, items = {}; "
+            f"var ls = window.{self.storage_key}, items = {{}}; "
             "for (var i = 0, k; i < ls.length; ++i) "
             "  items[k = ls.key(i)] = ls.getItem(k); "
             "return items; "
@@ -76,7 +78,7 @@ class LocalStorage:
             A list of keys.
         """
         return self.driver.execute_script(
-            "var ls = window.localStorage, keys = []; "
+            f"var ls = window.{self.storage_key}, keys = []; "
             "for (var i = 0; i < ls.length; ++i) "
             "  keys[i] = ls.key(i); "
             "return keys; "
@@ -92,7 +94,7 @@ class LocalStorage:
             The value of the key.
         """
         return self.driver.execute_script(
-            "return window.localStorage.getItem(arguments[0]);", key
+            f"return window.{self.storage_key}.getItem(arguments[0]);", key
         )
 
     def set(self, key, value) -> None:
@@ -103,7 +105,9 @@ class LocalStorage:
             value: The value to set the key to.
         """
         self.driver.execute_script(
-            "window.localStorage.setItem(arguments[0], arguments[1]);", key, value
+            f"window.{self.storage_key}.setItem(arguments[0], arguments[1]);",
+            key,
+            value,
         )
 
     def has(self, key) -> bool:
@@ -123,11 +127,13 @@ class LocalStorage:
         Args:
             key: The key to remove.
         """
-        self.driver.execute_script("window.localStorage.removeItem(arguments[0]);", key)
+        self.driver.execute_script(
+            f"window.{self.storage_key}.removeItem(arguments[0]);", key
+        )
 
     def clear(self) -> None:
         """Clear all local storage."""
-        self.driver.execute_script("window.localStorage.clear();")
+        self.driver.execute_script(f"window.{self.storage_key}.clear();")
 
     def __getitem__(self, key) -> str:
         """Get a key from local storage.

+ 47 - 2
reflex/app.py

@@ -7,10 +7,12 @@ import concurrent.futures
 import contextlib
 import copy
 import functools
+import inspect
 import io
 import multiprocessing
 import os
 import platform
+import sys
 from typing import (
     Any,
     AsyncIterator,
@@ -100,7 +102,50 @@ class OverlayFragment(Fragment):
     pass
 
 
-class App(Base):
+class LifespanMixin(Base):
+    """A Mixin that allow tasks to run during the whole app lifespan."""
+
+    # Lifespan tasks that are planned to run.
+    lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
+
+    @contextlib.asynccontextmanager
+    async def _run_lifespan_tasks(self, app: FastAPI):
+        running_tasks = []
+        try:
+            async with contextlib.AsyncExitStack() as stack:
+                for task in self.lifespan_tasks:
+                    if isinstance(task, asyncio.Task):
+                        running_tasks.append(task)
+                    else:
+                        signature = inspect.signature(task)
+                        if "app" in signature.parameters:
+                            task = functools.partial(task, app=app)
+                        _t = task()
+                        if isinstance(_t, contextlib._AsyncGeneratorContextManager):
+                            await stack.enter_async_context(_t)
+                        elif isinstance(_t, Coroutine):
+                            running_tasks.append(asyncio.create_task(_t))
+                yield
+        finally:
+            cancel_kwargs = (
+                {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
+            )
+            for task in running_tasks:
+                task.cancel(**cancel_kwargs)
+
+    def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
+        """Register a task to run during the lifespan of the app.
+
+        Args:
+            task: The task to register.
+            task_kwargs: The kwargs of the task.
+        """
+        if task_kwargs:
+            task = functools.partial(task, **task_kwargs)  # type: ignore
+        self.lifespan_tasks.add(task)  # type: ignore
+
+
+class App(LifespanMixin, Base):
     """The main Reflex app that encapsulates the backend and frontend.
 
     Every Reflex app needs an app defined in its main module.
@@ -203,7 +248,7 @@ class App(Base):
         self.middleware.append(HydrateMiddleware())
 
         # Set up the API.
-        self.api = FastAPI()
+        self.api = FastAPI(lifespan=self._run_lifespan_tasks)
         self._add_cors()
         self._add_default_endpoints()