Browse Source

Revert "avoid another cyclic import"

This reverts commit 69f94afa7c8d7e43421cb9cc31dc1a599fb97692.
Falko Schindler 1 year ago
parent
commit
7b5074cc2b

+ 2 - 32
nicegui/app.py

@@ -1,14 +1,11 @@
-import inspect
-from contextlib import nullcontext
 from pathlib import Path
 from pathlib import Path
-from typing import Any, Awaitable, Callable, Optional, Union
+from typing import Awaitable, Callable, Optional, Union
 
 
 from fastapi import FastAPI, HTTPException, Request
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import FileResponse, StreamingResponse
 from fastapi.responses import FileResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.staticfiles import StaticFiles
 
 
-from . import background_tasks, globals, helpers  # pylint: disable=redefined-builtin
-from .client import Client
+from . import globals, helpers  # pylint: disable=redefined-builtin
 from .native import Native
 from .native import Native
 from .observables import ObservableSet
 from .observables import ObservableSet
 from .storage import Storage
 from .storage import Storage
@@ -176,30 +173,3 @@ class App(FastAPI):
     def remove_route(self, path: str) -> None:
     def remove_route(self, path: str) -> None:
         """Remove routes with the given path."""
         """Remove routes with the given path."""
         self.routes[:] = [r for r in self.routes if getattr(r, 'path', None) != path]
         self.routes[:] = [r for r in self.routes if getattr(r, 'path', None) != path]
-
-    def handle_exception(self, exception: Exception) -> None:
-        """Handle an exception by invoking all registered exception handlers."""
-        for handler in self.exception_handlers:
-            result = handler() if not inspect.signature(handler).parameters else handler(exception)
-            if helpers.is_coroutine_function(handler):
-                background_tasks.create(result)
-
-    def safe_invoke(self, func: Union[Callable[..., Any], Awaitable], client: Optional[Client] = None) -> None:
-        """Invoke the potentially async function in the client context and catch any exceptions."""
-        try:
-            if isinstance(func, Awaitable):
-                async def func_with_client():
-                    with client or nullcontext():
-                        await func
-                background_tasks.create(func_with_client())
-            else:
-                with client or nullcontext():
-                    result = func(client) if len(inspect.signature(
-                        func).parameters) == 1 and client is not None else func()
-                if helpers.is_coroutine_function(func):
-                    async def result_with_client():
-                        with client or nullcontext():
-                            await result
-                    background_tasks.create(result_with_client())
-        except Exception as e:
-            self.handle_exception(e)

+ 1 - 1
nicegui/background_tasks.py

@@ -57,4 +57,4 @@ def _handle_task_result(task: asyncio.Task) -> None:
     except asyncio.CancelledError:
     except asyncio.CancelledError:
         pass
         pass
     except Exception as e:
     except Exception as e:
-        globals.app.handle_exception(e)
+        globals.handle_exception(e)

+ 2 - 2
nicegui/events.py

@@ -440,10 +440,10 @@ def handle_event(handler: Optional[Callable[..., Any]], arguments: EventArgument
                     try:
                     try:
                         await result
                         await result
                     except Exception as e:
                     except Exception as e:
-                        globals.app.handle_exception(e)
+                        globals.handle_exception(e)
             if globals.loop and globals.loop.is_running():
             if globals.loop and globals.loop.is_running():
                 background_tasks.create(wait_for_result(), name=str(handler))
                 background_tasks.create(wait_for_result(), name=str(handler))
             else:
             else:
                 globals.app.on_startup(wait_for_result())
                 globals.app.on_startup(wait_for_result())
     except Exception as e:
     except Exception as e:
-        globals.app.handle_exception(e)
+        globals.handle_exception(e)

+ 2 - 2
nicegui/functions/timer.py

@@ -81,7 +81,7 @@ class Timer:
                     except asyncio.CancelledError:
                     except asyncio.CancelledError:
                         break
                         break
                     except Exception as e:
                     except Exception as e:
-                        globals.app.handle_exception(e)
+                        globals.handle_exception(e)
                         await asyncio.sleep(self.interval)
                         await asyncio.sleep(self.interval)
         finally:
         finally:
             self._cleanup()
             self._cleanup()
@@ -93,7 +93,7 @@ class Timer:
             if helpers.is_coroutine_function(self.callback):
             if helpers.is_coroutine_function(self.callback):
                 await result
                 await result
         except Exception as e:
         except Exception as e:
-            globals.app.handle_exception(e)
+            globals.handle_exception(e)
 
 
     async def _connected(self, timeout: float = 60.0) -> bool:
     async def _connected(self, timeout: float = 60.0) -> bool:
         """Wait for the client connection before the timer callback can be allowed to manipulate the state.
         """Wait for the client connection before the timer callback can be allowed to manipulate the state.

+ 12 - 0
nicegui/globals.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import asyncio
 import asyncio
+import inspect
 import logging
 import logging
 import os
 import os
 from contextlib import contextmanager
 from contextlib import contextmanager
@@ -11,6 +12,9 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterator, List
 from socketio import AsyncServer
 from socketio import AsyncServer
 from uvicorn import Server
 from uvicorn import Server
 
 
+from . import background_tasks
+from .helpers import is_coroutine_function
+
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .air import Air
     from .air import Air
     from .app import App
     from .app import App
@@ -114,3 +118,11 @@ def socket_id(id_: str) -> Iterator[None]:
     _socket_id = id_
     _socket_id = id_
     yield
     yield
     _socket_id = None
     _socket_id = None
+
+
+def handle_exception(exception: Exception) -> None:
+    """Handle an exception by invoking all registered exception handlers."""
+    for handler in exception_handlers:
+        result = handler() if not inspect.signature(handler).parameters else handler(exception)
+        if is_coroutine_function(handler):
+            background_tasks.create(result)

+ 28 - 1
nicegui/helpers.py

@@ -3,18 +3,25 @@ from __future__ import annotations
 import asyncio
 import asyncio
 import functools
 import functools
 import hashlib
 import hashlib
+import inspect
 import mimetypes
 import mimetypes
 import socket
 import socket
 import sys
 import sys
 import threading
 import threading
 import time
 import time
 import webbrowser
 import webbrowser
+from contextlib import nullcontext
 from pathlib import Path
 from pathlib import Path
-from typing import Any, Generator, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Optional, Tuple, Union
 
 
 from fastapi import Request
 from fastapi import Request
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
 
 
+from . import background_tasks, globals  # pylint: disable=redefined-builtin
+
+if TYPE_CHECKING:
+    from .client import Client
+
 mimetypes.init()
 mimetypes.init()
 
 
 
 
@@ -51,6 +58,26 @@ def hash_file_path(path: Path) -> str:
     return hashlib.sha256(path.as_posix().encode()).hexdigest()[:32]
     return hashlib.sha256(path.as_posix().encode()).hexdigest()[:32]
 
 
 
 
+def safe_invoke(func: Union[Callable[..., Any], Awaitable], client: Optional[Client] = None) -> None:
+    """Invoke the potentially async function in the client context and catch any exceptions."""
+    try:
+        if isinstance(func, Awaitable):
+            async def func_with_client():
+                with client or nullcontext():
+                    await func
+            background_tasks.create(func_with_client())
+        else:
+            with client or nullcontext():
+                result = func(client) if len(inspect.signature(func).parameters) == 1 and client is not None else func()
+            if is_coroutine_function(func):
+                async def result_with_client():
+                    with client or nullcontext():
+                        await result
+                background_tasks.create(result_with_client())
+    except Exception as e:
+        globals.handle_exception(e)
+
+
 def is_port_open(host: str, port: int) -> bool:
 def is_port_open(host: str, port: int) -> bool:
     """Check if the port is open by checking if a TCP connection can be established."""
     """Check if the port is open by checking if a TCP connection can be established."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

+ 7 - 7
nicegui/nicegui.py

@@ -17,7 +17,7 @@ from .app import App
 from .client import Client
 from .client import Client
 from .dependencies import js_components, libraries
 from .dependencies import js_components, libraries
 from .error import error_content
 from .error import error_content
-from .helpers import is_file
+from .helpers import is_file, safe_invoke
 from .json import NiceGUIJSONResponse
 from .json import NiceGUIJSONResponse
 from .middlewares import RedirectWithPrefixMiddleware
 from .middlewares import RedirectWithPrefixMiddleware
 from .page import page
 from .page import page
@@ -94,7 +94,7 @@ def handle_startup(with_welcome_message: bool = True) -> None:
     globals.loop = asyncio.get_running_loop()
     globals.loop = asyncio.get_running_loop()
     with globals.index_client:
     with globals.index_client:
         for t in globals.startup_handlers:
         for t in globals.startup_handlers:
-            app.safe_invoke(t)
+            safe_invoke(t)
     background_tasks.create(binding.refresh_loop(), name='refresh bindings')
     background_tasks.create(binding.refresh_loop(), name='refresh bindings')
     background_tasks.create(outbox.loop(), name='send outbox')
     background_tasks.create(outbox.loop(), name='send outbox')
     background_tasks.create(prune_clients(), name='prune clients')
     background_tasks.create(prune_clients(), name='prune clients')
@@ -114,7 +114,7 @@ async def handle_shutdown() -> None:
     globals.state = globals.State.STOPPING
     globals.state = globals.State.STOPPING
     with globals.index_client:
     with globals.index_client:
         for t in globals.shutdown_handlers:
         for t in globals.shutdown_handlers:
-            app.safe_invoke(t)
+            safe_invoke(t)
     run_executor.tear_down()
     run_executor.tear_down()
     globals.state = globals.State.STOPPED
     globals.state = globals.State.STOPPED
     if globals.air:
     if globals.air:
@@ -154,9 +154,9 @@ def handle_handshake(client: Client) -> None:
         client.disconnect_task.cancel()
         client.disconnect_task.cancel()
         client.disconnect_task = None
         client.disconnect_task = None
     for t in client.connect_handlers:
     for t in client.connect_handlers:
-        app.safe_invoke(t, client)
+        safe_invoke(t, client)
     for t in globals.connect_handlers:
     for t in globals.connect_handlers:
-        app.safe_invoke(t, client)
+        safe_invoke(t, client)
 
 
 
 
 @sio.on('disconnect')
 @sio.on('disconnect')
@@ -176,9 +176,9 @@ async def handle_disconnect(client: Client) -> None:
     if not client.shared:
     if not client.shared:
         _delete_client(client.id)
         _delete_client(client.id)
     for t in client.disconnect_handlers:
     for t in client.disconnect_handlers:
-        app.safe_invoke(t, client)
+        safe_invoke(t, client)
     for t in globals.disconnect_handlers:
     for t in globals.disconnect_handlers:
-        app.safe_invoke(t, client)
+        safe_invoke(t, client)
 
 
 
 
 @sio.on('event')
 @sio.on('event')

+ 2 - 2
nicegui/outbox.py

@@ -65,9 +65,9 @@ async def loop() -> None:
                 try:
                 try:
                     await coro
                     await coro
                 except Exception as e:
                 except Exception as e:
-                    globals.app.handle_exception(e)
+                    globals.handle_exception(e)
         except Exception as e:
         except Exception as e:
-            globals.app.handle_exception(e)
+            globals.handle_exception(e)
             await asyncio.sleep(0.1)
             await asyncio.sleep(0.1)