Browse Source

Properly await coroutines registered with app.on_shutdown (#4641)

This PR started with the intent to fix #4592 which reported unawaited
coroutines when running pytests. It turns out, the warning lead down a
rabbit hole of problems. Here is a list of the major findings and fixes:

- the coroutines registered with `app.on_shutdown` have not been awaited
properly
- `background_tasks` where not canceled on shutdown
- `outbox.loop` did not exit when receiving a cancel command
- handle cancellation in internal house-keeping loops (pruning, binding,
...)
- before Python 3.12 `asyncio.wait_for` misses cancel commands in
certain conditions and must be fixed with
[`wait-for2`](https://pypi.org/project/wait-for2/)
- ensure that all storage backups are written before exiting
- introduce `@background_tasks.await_on_shutdown` annotation to mark
coroutines as "not to cancel on shutdown" (and thereby also fixes #4312)
- ~~clean up multiprocessing to not get "leaked semaphore" warnings (fix
was described in
https://github.com/zauberzeug/nicegui/issues/4131#issuecomment-2705100273)~~

Also this PR makes some minor improvements:

- disable pytest warning from upstream packages which we can do nothing
about
- fix UTC warnings
- add names to background_tasks where ever possible
- only stop screen tests if console output contains "ERROR" (same as
#4608 did for the user tests)
- first steps to add some `background_tasks` documentation

---------

Co-authored-by: Falko Schindler <falko@zauberzeug.com>
Rodja Trappe 2 weeks ago
parent
commit
2b8b1a6969

+ 2 - 2
.github/workflows/update_version.py

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import sys
-from datetime import UTC, datetime
+from datetime import datetime, timezone
 from pathlib import Path
 
 if __name__ == '__main__':
@@ -20,5 +20,5 @@ if __name__ == '__main__':
         if line.startswith('version: '):
             lines[i] = f'version: {version}'
         if line.startswith('date-released: '):
-            lines[i] = f'date-released: "{datetime.now(UTC).strftime(r"%Y-%m-%d")}"'
+            lines[i] = f'date-released: "{datetime.now(timezone.utc).strftime(r"%Y-%m-%d")}"'
     path.write_text('\n'.join(lines) + '\n', encoding='utf-8')

+ 3 - 4
nicegui/air.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import asyncio
 import gzip
 import json
 import logging
@@ -225,7 +224,7 @@ class Air:
         self.connecting = True
         try:
             if self.relay.connected:
-                await asyncio.wait_for(self.disconnect(), timeout=5)
+                await helpers.wait_for(self.disconnect(), timeout=5)
             self.log.debug('Connecting...')
             await self.relay.connect(
                 f'{RELAY_HOST}?device_token={self.token}',
@@ -270,10 +269,10 @@ class Air:
 def connect() -> None:
     """Connect to the NiceGUI On Air server if there is an air instance."""
     if core.air:
-        background_tasks.create(core.air.connect())
+        background_tasks.create(core.air.connect(), name='On Air connect')
 
 
 def disconnect() -> None:
     """Disconnect from the NiceGUI On Air server if there is an air instance."""
     if core.air:
-        background_tasks.create(core.air.disconnect())
+        background_tasks.create(core.air.disconnect(), name='On Air disconnect')

+ 12 - 6
nicegui/app/app.py

@@ -46,8 +46,6 @@ class App(FastAPI):
         self._disconnect_handlers: List[Union[Callable[..., Any], Awaitable]] = []
         self._exception_handlers: List[Callable[..., Any]] = [log.exception]
 
-        self.on_shutdown(self.storage.on_shutdown)
-
     @property
     def is_starting(self) -> bool:
         """Return whether NiceGUI is starting."""
@@ -73,13 +71,21 @@ class App(FastAPI):
         self._state = State.STARTING
         for t in self._startup_handlers:
             Client.auto_index_client.safe_invoke(t)
+        self.on_shutdown(self.storage.on_shutdown)
+        self.on_shutdown(background_tasks.teardown)
         self._state = State.STARTED
 
-    def stop(self) -> None:
+    async def stop(self) -> None:
         """Stop NiceGUI. (For internal use only.)"""
         self._state = State.STOPPING
-        for t in self._shutdown_handlers:
-            Client.auto_index_client.safe_invoke(t)
+        with Client.auto_index_client:
+            for t in self._shutdown_handlers:
+                if isinstance(t, Awaitable):
+                    await t
+                else:
+                    result = t(self) if len(inspect.signature(t).parameters) == 1 else t()
+                    if helpers.is_coroutine_function(t):
+                        await result
         self._state = State.STOPPED
 
     def on_connect(self, handler: Union[Callable, Awaitable]) -> None:
@@ -124,7 +130,7 @@ class App(FastAPI):
         for handler in self._exception_handlers:
             result = handler() if not inspect.signature(handler).parameters else handler(exception)
             if helpers.is_coroutine_function(handler):
-                background_tasks.create(result)
+                background_tasks.create(result, name=f'exception {handler.__name__}')
 
     def shutdown(self) -> None:
         """Shut down NiceGUI.

+ 2 - 2
nicegui/app/range_response.py

@@ -1,6 +1,6 @@
 import hashlib
 import mimetypes
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Generator
 
@@ -13,7 +13,7 @@ mimetypes.init()
 def get_range_response(file: Path, request: Request, chunk_size: int) -> Response:
     """Get a Response for the given file, supporting range-requests, E-Tag and Last-Modified."""
     file_size = file.stat().st_size
-    last_modified_time = datetime.utcfromtimestamp(file.stat().st_mtime)
+    last_modified_time = datetime.fromtimestamp(file.stat().st_mtime, timezone.utc)
     start = 0
     end = file_size - 1
     status_code = 200

+ 60 - 9
nicegui/background_tasks.py

@@ -2,13 +2,16 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Awaitable, Dict, Set
+import weakref
+from typing import Any, Awaitable, Callable, Coroutine, Dict, Set, TypeVar
 
-from . import core
+from . import core, helpers
+from .logging import log
 
 running_tasks: Set[asyncio.Task] = set()
 lazy_tasks_running: Dict[str, asyncio.Task] = {}
-lazy_tasks_waiting: Dict[str, Awaitable] = {}
+lazy_coroutines_waiting: Dict[str, Coroutine[Any, Any, Any]] = {}
+functions_awaited_on_shutdown: weakref.WeakSet[Callable] = weakref.WeakSet()
 
 
 def create(coroutine: Awaitable, *, name: str = 'unnamed task') -> asyncio.Task:
@@ -19,7 +22,7 @@ def create(coroutine: Awaitable, *, name: str = 'unnamed task') -> asyncio.Task:
     See https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task.
     """
     assert core.loop is not None
-    coroutine = coroutine if asyncio.iscoroutine(coroutine) else asyncio.wait_for(coroutine, None)
+    coroutine = coroutine if asyncio.iscoroutine(coroutine) else helpers.wait_for(coroutine, None)
     task: asyncio.Task = core.loop.create_task(coroutine, name=name)
     task.add_done_callback(_handle_task_result)
     running_tasks.add(task)
@@ -33,20 +36,39 @@ def create_lazy(coroutine: Awaitable, *, name: str) -> None:
     If a third task with the same name is created while the first one is still running, the second one is discarded.
     """
     if name in lazy_tasks_running:
-        if name in lazy_tasks_waiting:
-            asyncio.Task(lazy_tasks_waiting[name]).cancel()
-        lazy_tasks_waiting[name] = coroutine
+        if name in lazy_coroutines_waiting:
+            lazy_coroutines_waiting[name].close()
+        lazy_coroutines_waiting[name] = _ensure_coroutine(coroutine)
         return
 
     def finalize(name: str) -> None:
         lazy_tasks_running.pop(name)
-        if name in lazy_tasks_waiting:
-            create_lazy(lazy_tasks_waiting.pop(name), name=name)
+        if name in lazy_coroutines_waiting:
+            create_lazy(lazy_coroutines_waiting.pop(name), name=name)
     task = create(coroutine, name=name)
     lazy_tasks_running[name] = task
     task.add_done_callback(lambda _: finalize(name))
 
 
+F = TypeVar('F', bound=Callable)
+
+
+def await_on_shutdown(func: F) -> F:
+    """Tag a coroutine function so tasks created from it won't be cancelled during shutdown."""
+    functions_awaited_on_shutdown.add(func)
+    return func
+
+
+def _ensure_coroutine(awaitable: Awaitable[Any]) -> Coroutine[Any, Any, Any]:
+    """Convert an awaitable to a coroutine if it isn't already one."""
+    if asyncio.iscoroutine(awaitable):
+        return awaitable
+
+    async def wrapper() -> Any:
+        return await awaitable
+    return wrapper()
+
+
 def _handle_task_result(task: asyncio.Task) -> None:
     try:
         task.result()
@@ -54,3 +76,32 @@ def _handle_task_result(task: asyncio.Task) -> None:
         pass
     except Exception as e:
         core.app.handle_exception(e)
+
+
+async def teardown() -> None:
+    """Cancel all running tasks and coroutines on shutdown. (For internal use only.)"""
+    while running_tasks or lazy_tasks_running:
+        tasks = running_tasks | set(lazy_tasks_running.values())
+        for task in tasks:
+            if not task.done() and not task.cancelled() and not _should_await_on_shutdown(task):
+                task.cancel()
+        if tasks:
+            await asyncio.sleep(0)  # NOTE: ensure the loop can cancel the tasks before it shuts down
+            try:
+                await helpers.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=2.0)
+            except asyncio.TimeoutError:
+                log.error('Could not cancel %s tasks within timeout: %s',
+                          len(tasks),
+                          ', '.join(t.get_name() for t in tasks if not t.done()))
+            except Exception:
+                log.exception('Error while cancelling tasks')
+    for coro in lazy_coroutines_waiting.values():
+        coro.close()
+
+
+def _should_await_on_shutdown(task: asyncio.Task) -> bool:
+    try:
+        return any(fn.__code__ is task.get_coro().cr_frame.f_code  # type: ignore
+                   for fn in functions_awaited_on_shutdown)
+    except AttributeError:
+        return False

+ 4 - 1
nicegui/binding.py

@@ -67,7 +67,10 @@ async def refresh_loop() -> None:
     """Refresh all bindings in an endless loop."""
     while True:
         _refresh_step()
-        await asyncio.sleep(core.app.config.binding_refresh_interval)
+        try:
+            await asyncio.sleep(core.app.config.binding_refresh_interval)
+        except asyncio.CancelledError:
+            break
 
 
 @contextmanager

+ 10 - 5
nicegui/client.py

@@ -280,7 +280,8 @@ class Client:
                 self._delete_tasks.pop(document_id)
                 if not self.shared:
                     self.delete()
-        self._delete_tasks[document_id] = background_tasks.create(delete_content())
+        self._delete_tasks[document_id] = \
+            background_tasks.create(delete_content(), name=f'delete content {document_id}')
 
     def _cancel_delete_task(self, document_id: str) -> None:
         if document_id in self._delete_tasks:
@@ -302,20 +303,21 @@ class Client:
 
     def safe_invoke(self, func: Union[Callable[..., Any], Awaitable]) -> None:
         """Invoke the potentially async function in the client context and catch any exceptions."""
+        func_name = func.__name__ if hasattr(func, '__name__') else str(func)
         try:
             if isinstance(func, Awaitable):
                 async def func_with_client():
                     with self:
                         await func
-                background_tasks.create(func_with_client())
+                background_tasks.create(func_with_client(), name=f'func with client {self.id} {func_name}')
             else:
                 with self:
                     result = func(self) if len(inspect.signature(func).parameters) == 1 else func()
-                if helpers.is_coroutine_function(func):
+                if helpers.is_coroutine_function(func) and not isinstance(result, asyncio.Task):
                     async def result_with_client():
                         with self:
                             await result
-                    background_tasks.create(result_with_client())
+                    background_tasks.create(result_with_client(), name=f'result with client {self.id} {func_name}')
         except Exception as e:
             core.app.handle_exception(e)
 
@@ -377,4 +379,7 @@ class Client:
             except Exception:
                 # NOTE: make sure the loop doesn't crash
                 log.exception('Error while pruning clients')
-            await asyncio.sleep(10)
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                break

+ 1 - 1
nicegui/elements/mixins/validation_element.py

@@ -70,7 +70,7 @@ class ValidationElement(ValueElement):
                 self.error = await result
             if return_result:
                 raise NotImplementedError('The validate method cannot return results for async validation functions.')
-            background_tasks.create(await_error())
+            background_tasks.create(await_error(), name=f'validate {self.id}')
             return True
 
         if callable(self._validation):

+ 2 - 1
nicegui/events.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from contextlib import nullcontext
 from dataclasses import dataclass
 from inspect import Parameter, signature
@@ -436,7 +437,7 @@ def handle_event(handler: Optional[Handler[EventT]], arguments: EventT) -> None:
                 result = cast(Callable[[EventT], Any], handler)(arguments)
             else:
                 result = cast(Callable[[], Any], handler)()
-        if isinstance(result, Awaitable) and not isinstance(result, AwaitableResponse):
+        if isinstance(result, Awaitable) and not isinstance(result, AwaitableResponse) and not isinstance(result, asyncio.Task):
             # NOTE: await an awaitable result even if the handler is not a coroutine (like a lambda statement)
             async def wait_for_result():
                 with parent_slot:

+ 1 - 1
nicegui/functions/refreshable.py

@@ -115,7 +115,7 @@ class refreshable(Generic[_P, _T]):
             if is_coroutine_function(self.func):
                 assert isinstance(result, Awaitable)
                 if core.loop and core.loop.is_running():
-                    background_tasks.create(result)
+                    background_tasks.create(result, name=f'refresh {self.func.__name__}')
                 else:
                     core.app.on_startup(result)
 

+ 12 - 0
nicegui/helpers.py

@@ -6,9 +6,12 @@ import socket
 import threading
 import time
 import webbrowser
+from collections.abc import Awaitable
 from pathlib import Path
 from typing import Any, Optional, Set, Tuple, Union
 
+import wait_for2
+
 from .logging import log
 
 _shown_warnings: Set[str] = set()
@@ -106,3 +109,12 @@ def kebab_to_camel_case(string: str) -> str:
 def event_type_to_camel_case(string: str) -> str:
     """Convert an event type string to camelCase."""
     return '.'.join(kebab_to_camel_case(part) if part != '-' else part for part in string.split('.'))
+
+
+async def wait_for(fut: Awaitable, timeout: Optional[float] = None) -> None:
+    """Wait for a future to complete.
+
+    This function is a wrapper around ``wait_for2.wait_for`` which is a drop-in replacement for ``asyncio.wait_for``.
+    It can be removed once we drop support for older versions than Python 3.13 which fixes ``asyncio.wait_for``.
+    """
+    return await wait_for2.wait_for(fut, timeout)

+ 3 - 1
nicegui/javascript_request.py

@@ -3,6 +3,8 @@ from __future__ import annotations
 import asyncio
 from typing import Any, ClassVar, Dict
 
+from . import helpers
+
 
 class JavaScriptRequest:
     _instances: ClassVar[Dict[str, JavaScriptRequest]] = {}
@@ -23,7 +25,7 @@ class JavaScriptRequest:
 
     def __await__(self) -> Any:
         try:
-            yield from asyncio.wait_for(self._event.wait(), self.timeout).__await__()
+            yield from helpers.wait_for(self._event.wait(), self.timeout).__await__()
         except asyncio.TimeoutError as e:
             raise TimeoutError(f'JavaScript did not respond within {self.timeout:.1f} s') from e
         else:

+ 1 - 1
nicegui/nicegui.py

@@ -139,7 +139,7 @@ async def _shutdown() -> None:
     if app.native.main_window:
         app.native.main_window.signal_server_shutdown()
     air.disconnect()
-    app.stop()
+    await app.stop()
     run.tear_down()
 
 

+ 4 - 3
nicegui/outbox.py

@@ -5,7 +5,7 @@ import time
 from collections import deque
 from typing import TYPE_CHECKING, Any, Deque, Dict, Optional, Tuple
 
-from . import background_tasks, core
+from . import background_tasks, core, helpers
 
 if TYPE_CHECKING:
     from .client import Client
@@ -72,7 +72,7 @@ class Outbox:
             try:
                 if not self._enqueue_event.is_set():
                     try:
-                        await asyncio.wait_for(self._enqueue_event.wait(), timeout=1.0)
+                        await helpers.wait_for(self._enqueue_event.wait(), timeout=1.0)
                     except (TimeoutError, asyncio.TimeoutError):
                         continue
 
@@ -101,7 +101,8 @@ class Outbox:
                         await coro
                     except Exception as e:
                         core.app.handle_exception(e)
-
+            except asyncio.CancelledError:
+                break
             except Exception as e:
                 core.app.handle_exception(e)
                 await asyncio.sleep(0.1)

+ 6 - 4
nicegui/persistence/file_persistent_dict.py

@@ -45,13 +45,15 @@ class FilePersistentDict(PersistentDict):
                 return
             self.filepath.parent.mkdir(exist_ok=True)
 
-        async def backup() -> None:
+        @background_tasks.await_on_shutdown
+        async def async_backup() -> None:
             async with aiofiles.open(self.filepath, 'w', encoding=self.encoding) as f:
                 await f.write(json.dumps(self, indent=self.indent))
-        if core.loop:
-            background_tasks.create_lazy(backup(), name=self.filepath.stem)
+
+        if core.loop and core.loop.is_running():
+            background_tasks.create_lazy(async_backup(), name=self.filepath.stem)
         else:
-            core.app.on_startup(backup())
+            self.filepath.write_text(json.dumps(self, indent=self.indent), encoding=self.encoding)
 
     def clear(self) -> None:
         super().clear()

+ 4 - 1
nicegui/slot.py

@@ -59,7 +59,10 @@ class Slot:
             except Exception:
                 # NOTE: make sure the loop doesn't crash
                 log.exception('Error while pruning slot stacks')
-            await asyncio.sleep(10)
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                break
 
 
 def get_task_id() -> int:

+ 5 - 2
nicegui/storage.py

@@ -189,7 +189,10 @@ class Storage:
                     if isinstance(tab, PersistentDict):
                         await tab.close()
                     del self._tabs[tab_id]
-            await asyncio.sleep(PURGE_INTERVAL)
+            try:
+                await asyncio.sleep(PURGE_INTERVAL)
+            except asyncio.CancelledError:
+                break
 
     def clear(self) -> None:
         """Clears all storage."""
@@ -208,7 +211,7 @@ class Storage:
             self.path.rmdir()
 
     async def on_shutdown(self) -> None:
-        """Close all persistent storage."""
+        """Close all persistent storage. (For internal use only.)"""
         for user in self._users.values():
             await user.close()
         await self._general.close()

+ 2 - 2
nicegui/testing/screen_plugin.py

@@ -75,11 +75,11 @@ def screen(nicegui_reset_globals,  # noqa: F811, pylint: disable=unused-argument
     prepare_simulation(request)
     screen_ = Screen(nicegui_driver, caplog)
     yield screen_
-    logs = screen_.caplog.get_records('call')
+    logs = [record for record in screen_.caplog.get_records('call') if record.levelname == 'ERROR']
     if screen_.is_open:
         screen_.shot(request.node.name)
     screen_.stop_server()
     if DOWNLOAD_DIR.exists():
         shutil.rmtree(DOWNLOAD_DIR)
     if logs:
-        pytest.fail('There were unexpected logs. See "Captured log call" below.', pytrace=False)
+        pytest.fail('There were unexpected ERROR logs.', pytrace=False)

+ 2 - 1
nicegui/testing/user_download.py

@@ -21,7 +21,8 @@ class UserDownload(Download):
         self.user = user
 
     def __call__(self, src: Union[str, Path, bytes], filename: Optional[str] = None, media_type: str = '') -> Any:
-        background_tasks.create(self._get(src))
+        background_tasks.create(self._get(src),
+                                name=f'download {str(src[:10]) + "..." if isinstance(src, bytes) else src}')
 
     def file(self, path: Union[str, Path], filename: Optional[str] = None, media_type: str = '') -> None:
         self(path)

+ 1 - 1
nicegui/testing/user_interaction.py

@@ -72,7 +72,7 @@ class UserInteraction(Generic[T]):
             for element in self.elements:
                 if isinstance(element, ui.link):
                     href = element.props.get('href', '#')
-                    background_tasks.create(self.user.open(href))
+                    background_tasks.create(self.user.open(href), name=f'open {href}')
                     return self
 
                 if isinstance(element, ui.select):

+ 7 - 4
nicegui/testing/user_navigate.py

@@ -21,20 +21,23 @@ class UserNavigate(Navigate):
             # NOTE navigation to an element does not do anything in the user simulation (the whole content is always visible)
             return
         path = Client.page_routes[target] if callable(target) else target
-        background_tasks.create(self.user.open(path))
+        background_tasks.create(self.user.open(path), name=f'navigate to {path}')
 
     def back(self) -> None:
         current = self.user.back_history.pop()
         self.user.forward_history.append(current)
         target = self.user.back_history.pop()
-        background_tasks.create(self.user.open(target, clear_forward_history=False))
+        background_tasks.create(self.user.open(target, clear_forward_history=False), name=f'navigate back to {target}')
 
     def forward(self) -> None:
         if not self.user.forward_history:
             return
         target = self.user.forward_history[0]
         del self.user.forward_history[0]
-        background_tasks.create(self.user.open(target, clear_forward_history=False))
+        background_tasks.create(self.user.open(target, clear_forward_history=False),
+                                name=f'navigate forward to {target}')
 
     def reload(self) -> None:
-        background_tasks.create(self.user.open(self.user.back_history.pop(), clear_forward_history=False))
+        target = self.user.back_history.pop()
+        background_tasks.create(self.user.open(target, clear_forward_history=False),
+                                name=f'navigate reload to {target}')

+ 12 - 1
poetry.lock

@@ -3951,6 +3951,17 @@ platformdirs = ">=3.9.1,<5"
 docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
 test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""]
 
+[[package]]
+name = "wait-for2"
+version = "0.3.2"
+description = "Asyncio wait_for that can handle simultaneous cancellation and future completion."
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+    {file = "wait_for2-0.3.2.tar.gz", hash = "sha256:93863026dc35f3471104ecf7de1f4a0b31f4c8b12a2241c0d6ee26dcc0c2092a"},
+]
+
 [[package]]
 name = "watchdog"
 version = "4.0.2"
@@ -4397,4 +4408,4 @@ sass = ["libsass"]
 [metadata]
 lock-version = "2.1"
 python-versions = "^3.8"
-content-hash = "f43d79075a6b85c874c493c1b888c22ea0ba72e55d8a42c85fc59d8fd4be73fd"
+content-hash = "fd49862eeebd95f8eb1fe6d09e3a98ee41c2d1003f54b616745e8a30df33ae71"

+ 7 - 0
pyproject.toml

@@ -42,6 +42,7 @@ urllib3 = ">=1.26.18,!=2.0.0,!=2.0.1,!=2.0.2,!=2.0.3,!=2.0.4,!=2.0.5,!=2.0.6,!=2
 certifi = ">=2024.07.04" # https://github.com/zauberzeug/nicegui/security/dependabot/35
 redis = { version = ">=4.0.0", optional = true }
 h11 = ">=0.16.0" # https://github.com/zauberzeug/nicegui/security/dependabot/45
+wait_for2 = ">=0.3.2"
 
 [tool.poetry.extras]
 native = ["pywebview"]
@@ -101,6 +102,11 @@ addopts = "--driver Chrome"
 asyncio_mode = "auto"
 testpaths = ["tests"]
 asyncio_default_fixture_loop_scope = "function"
+filterwarnings = [
+  'ignore::DeprecationWarning:^vbuild(\.|$)',
+  'ignore::DeprecationWarning:^websockets\.legacy(\.|$)',
+  'ignore::DeprecationWarning:^uvicorn\.protocols\.websockets(\.|$)',
+]
 
 [tool.mypy]
 python_version = "3.8"
@@ -118,6 +124,7 @@ module = [
     "sass",
     "socketio.*",
     "vbuild",
+    "wait_for2",
     "webview.*", # can be removed with next pywebview release
 ]
 ignore_missing_imports = true

+ 70 - 0
tests/test_background_tasks.py

@@ -0,0 +1,70 @@
+import asyncio
+
+import pytest
+
+from nicegui import app, background_tasks, ui
+from nicegui.testing import User
+
+# pylint: disable=missing-function-docstring
+
+
+# NOTE: click handlers, and system events used to wrap background_task in a background_task (see https://github.com/zauberzeug/nicegui/pull/4641#issuecomment-2837448265)
+@pytest.mark.parametrize('strategy', ['direct', 'click', 'system'])
+async def test_awaiting_background_tasks_on_shutdown(user: User, strategy: str):
+    run = set()
+    cancelled = set()
+
+    async def one():
+        try:
+            run.add('one')
+            await asyncio.sleep(1)
+        except asyncio.CancelledError:
+            cancelled.add('one')
+
+    @background_tasks.await_on_shutdown
+    async def two():
+        try:
+            run.add('two')
+            await asyncio.sleep(1)
+            background_tasks.create(three(), name='three')
+            background_tasks.create(four(), name='four')
+        except asyncio.CancelledError:
+            cancelled.add('two')
+
+    async def three():
+        try:
+            await asyncio.sleep(0.1)
+            run.add('three')
+        except asyncio.CancelledError:
+            cancelled.add('three')
+
+    @background_tasks.await_on_shutdown
+    async def four():
+        try:
+            await asyncio.sleep(0.1)
+            run.add('four')
+        except asyncio.CancelledError:
+            cancelled.add('four')
+
+    ui.button('One', on_click=lambda: background_tasks.create(one(), name='one'))
+    ui.button('Two', on_click=lambda: background_tasks.create(two(), name='two'))
+
+    if strategy == 'system':
+        app.on_connect(lambda: background_tasks.create(one(), name='one'))
+        app.on_connect(lambda: background_tasks.create(two(), name='two'))
+
+    await user.open('/')
+
+    if strategy == 'click':
+        user.find('One').click()
+        user.find('Two').click()
+    elif strategy == 'direct':
+        background_tasks.create(one(), name='one')
+        background_tasks.create(two(), name='two')
+
+    await asyncio.sleep(0.1)  # NOTE: we need to wait for the tasks to be created
+
+    # NOTE: teardown is called on shutdown; here we call it directly to test the teardown logic while test is still running
+    await background_tasks.teardown()
+    assert cancelled == {'one', 'three'}
+    assert run == {'one', 'two', 'four'}

+ 6 - 5
tests/test_storage.py

@@ -1,5 +1,6 @@
 import asyncio
 import copy
+import time
 from pathlib import Path
 
 import httpx
@@ -79,7 +80,7 @@ def test_user_storage_modifications(screen: Screen):
     screen.should_contain('3')
 
 
-async def test_access_user_storage_from_fastapi(screen: Screen):
+def test_access_user_storage_from_fastapi(screen: Screen):
     @app.get('/api')
     def api():
         app.storage.user['msg'] = 'yes'
@@ -87,11 +88,11 @@ async def test_access_user_storage_from_fastapi(screen: Screen):
 
     screen.ui_run_kwargs['storage_secret'] = 'just a test'
     screen.open('/')
-    async with httpx.AsyncClient() as http_client:
-        response = await http_client.get(f'http://localhost:{Screen.PORT}/api')
+    with httpx.Client() as http_client:
+        response = http_client.get(f'http://localhost:{Screen.PORT}/api')
         assert response.status_code == 200
         assert response.text == '"OK"'
-        await asyncio.sleep(0.5)  # wait for storage to be written
+        time.sleep(0.5)  # wait for storage to be written
         assert next(Path('.nicegui').glob('storage-user-*.json')).read_text(encoding='utf-8') == '{"msg":"yes"}'
 
 
@@ -122,7 +123,7 @@ def test_access_user_storage_from_button_click_handler(screen: Screen):
         next(Path('.nicegui').glob('storage-user-*.json')).read_text(encoding='utf-8') == '{"inner_function":"works"}'
 
 
-async def test_access_user_storage_from_background_task(screen: Screen):
+def test_access_user_storage_from_background_task(screen: Screen):
     @ui.page('/')
     def page():
         async def subtask():

+ 1 - 0
website/documentation/content/overview.py

@@ -446,6 +446,7 @@ def map_of_nicegui():
 
         - `create()`: create a background task
         - `create_lazy()`: prevent two tasks with the same name from running at the same time
+        - `await_on_shutdown`: mark a coroutine function to be awaited during shutdown (by default all background tasks are cancelled)
 
         #### `run`
 

+ 29 - 0
website/documentation/content/section_configuration_deployment.py

@@ -119,6 +119,35 @@ def env_var_demo():
     ui.label(f'Markdown content cache size is {markdown.prepare_content.cache_info().maxsize}')
 
 
+@doc.demo('Background Tasks', '''
+    `background_tasks.create()` allows you to run an async function in the background and return a task object.
+    By default the task will be automatically cancelled during shutdown.
+    You can prevent this by using the `@background_tasks.await_on_shutdown` decorator.
+    This is useful for tasks that need to be completed even when the app is shutting down.
+''')
+def background_tasks_demo():
+    from nicegui import background_tasks
+    import asyncio
+    import aiofiles
+
+    results = {'answer': '?'}
+
+    async def compute() -> None:
+        await asyncio.sleep(1)
+        results['answer'] = 42
+
+    @background_tasks.await_on_shutdown
+    async def backup() -> None:
+        await asyncio.sleep(1)
+        # async with aiofiles.open('backup.json', 'w') as f:
+        #     await f.write(f'{results["answer"]}')
+        # print('backup.json written', flush=True)
+
+    ui.label().bind_text_from(results, 'answer', lambda x: f'answer: {x}')
+    ui.button('Compute', on_click=lambda: background_tasks.create(compute()))
+    ui.button('Backup', on_click=lambda: background_tasks.create(backup()))
+
+
 doc.text('Custom Vue Components', '''
     You can create custom components by subclassing `ui.element` and implementing a corresponding Vue component.
     The ["Custom Vue components" example](https://github.com/zauberzeug/nicegui/tree/main/examples/custom_vue_component)