Browse Source

move air instance out of globals.py; move client-related functions out of nicegui.py

Falko Schindler 1 year ago
parent
commit
fcdd731a93
6 changed files with 110 additions and 104 deletions
  1. 29 9
      nicegui/air.py
  2. 65 1
      nicegui/client.py
  3. 0 2
      nicegui/globals.py
  4. 10 79
      nicegui/nicegui.py
  5. 4 10
      nicegui/outbox.py
  6. 2 3
      nicegui/run.py

+ 29 - 9
nicegui/air.py

@@ -1,16 +1,15 @@
 import asyncio
 import gzip
 import re
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import httpx
 import socketio
-from socketio import AsyncClient
+import socketio.exceptions
 
 from . import background_tasks, globals  # pylint: disable=redefined-builtin
 from .client import Client
 from .logging import log
-from .nicegui import handle_disconnect, handle_event, handle_handshake, handle_javascript_response
 
 RELAY_HOST = 'https://on-air.nicegui.io/'
 
@@ -19,7 +18,7 @@ class Air:
 
     def __init__(self, token: str) -> None:
         self.token = token
-        self.relay = AsyncClient()
+        self.relay = socketio.AsyncClient()
         self.client = httpx.AsyncClient(app=globals.app)
         self.connecting = False
 
@@ -70,7 +69,7 @@ class Air:
             client = Client.instances[client_id]
             client.environ = data['environ']
             client.on_air = True
-            handle_handshake(client)
+            client.handle_handshake()
             return True
 
         @self.relay.on('client_disconnect')
@@ -78,8 +77,7 @@ class Air:
             client_id = data['client_id']
             if client_id not in Client.instances:
                 return
-            client = Client.instances[client_id]
-            client.disconnect_task = background_tasks.create(handle_disconnect(client))
+            Client.instances[client_id].handle_disconnect()
 
         @self.relay.on('event')
         def _handle_event(data: Dict[str, Any]) -> None:
@@ -89,7 +87,7 @@ class Air:
             client = Client.instances[client_id]
             if isinstance(data['msg']['args'], dict) and 'socket_id' in data['msg']['args']:
                 data['msg']['args']['socket_id'] = client_id  # HACK: translate socket_id of ui.scene's init event
-            handle_event(client, data['msg'])
+            client.handle_event(data['msg'])
 
         @self.relay.on('javascript_response')
         def _handle_javascript_response(data: Dict[str, Any]) -> None:
@@ -97,7 +95,7 @@ class Air:
             if client_id not in Client.instances:
                 return
             client = Client.instances[client_id]
-            handle_javascript_response(client, data['msg'])
+            client.handle_javascript_response(data['msg'])
 
         @self.relay.on('out_of_time')
         async def _handle_out_of_time() -> None:
@@ -143,3 +141,25 @@ class Air:
         """Emit a message to the NiceGUI On Air server."""
         if self.relay.connected:
             await self.relay.emit('forward', {'event': message_type, 'data': data, 'room': room})
+
+    @staticmethod
+    def is_air_target(target_id: str) -> bool:
+        """Whether the given target ID is an On Air client or a SocketIO room."""
+        if target_id in Client.instances:
+            return Client.instances[target_id].on_air
+        return target_id in globals.sio.manager.rooms
+
+
+instance: Optional[Air] = None
+
+
+def connect() -> None:
+    """Connect to the NiceGUI On Air server if there is an air instance."""
+    if instance:
+        background_tasks.create(instance.connect())
+
+
+def disconnect() -> None:
+    """Disconnect from the NiceGUI On Air server if there is an air instance."""
+    if instance:
+        background_tasks.create(instance.disconnect())

+ 65 - 1
nicegui/client.py

@@ -13,11 +13,12 @@ from fastapi.templating import Jinja2Templates
 
 from nicegui import json
 
-from . import binding, globals, outbox  # pylint: disable=redefined-builtin
+from . import background_tasks, binding, globals, outbox  # pylint: disable=redefined-builtin
 from .awaitable_response import AwaitableResponse
 from .dependencies import generate_resources
 from .element import Element
 from .favicon import get_favicon_url
+from .helpers import safe_invoke
 from .logging import log
 from .version import __version__
 
@@ -194,6 +195,43 @@ class Client:
         """Register a callback to be called when the client disconnects."""
         self.disconnect_handlers.append(handler)
 
+    def handle_handshake(self) -> None:
+        """Cancel pending disconnect task and invoke connect handlers."""
+        if self.disconnect_task:
+            self.disconnect_task.cancel()
+            self.disconnect_task = None
+        for t in self.connect_handlers:
+            safe_invoke(t, self)
+        for t in globals.app._connect_handlers:  # pylint: disable=protected-access
+            safe_invoke(t, self)
+
+    def handle_disconnect(self) -> None:
+        """Wait for the browser to reconnect; invoke disconnect handlers if it doesn't."""
+        async def handle_disconnect() -> None:
+            delay = self.page.reconnect_timeout if self.page.reconnect_timeout is not None else globals.reconnect_timeout
+            await asyncio.sleep(delay)
+            if not self.shared:
+                self.delete()
+            for t in self.disconnect_handlers:
+                safe_invoke(t, self)
+            for t in globals.app._disconnect_handlers:  # pylint: disable=protected-access
+                safe_invoke(t, self)
+        self.disconnect_task = background_tasks.create(handle_disconnect())
+
+    def handle_event(self, msg: Dict) -> None:
+        """Forward an event to the corresponding element."""
+        with self:
+            sender = self.elements.get(msg['id'])
+            if sender:
+                msg['args'] = [None if arg is None else json.loads(arg) for arg in msg.get('args', [])]
+                if len(msg['args']) == 1:
+                    msg['args'] = msg['args'][0]
+                sender._handle_event(msg)  # pylint: disable=protected-access
+
+    def handle_javascript_response(self, msg: Dict) -> None:
+        """Store the result of a JavaScript command."""
+        self.waiting_javascript_commands[msg['request_id']] = msg['result']
+
     def remove_elements(self, elements: Iterable[Element]) -> None:
         """Remove the given elements from the client."""
         binding.remove(elements, Element)
@@ -209,6 +247,15 @@ class Client:
         """Remove all elements from the client."""
         self.remove_elements(self.elements.values())
 
+    def delete(self) -> None:
+        """Delete a client and all its elements.
+
+        If the global clients dictionary does not contain the client, its elements are still removed and a KeyError is raised.
+        Normally this should never happen, but has been observed (see #1826).
+        """
+        self.remove_all_elements()
+        del Client.instances[self.id]
+
     @contextmanager
     def individual_target(self, socket_id: str) -> Iterator[None]:
         """Use individual socket ID while in this context.
@@ -218,3 +265,20 @@ class Client:
         self._temporary_socket_id = socket_id
         yield
         self._temporary_socket_id = None
+
+    @classmethod
+    async def prune_instances(cls) -> None:
+        """Prune stale clients in an endless loop."""
+        while True:
+            try:
+                stale_clients = [
+                    client
+                    for client in cls.instances.values()
+                    if not client.shared and not client.has_socket_connection and client.created < time.time() - 60.0
+                ]
+                for client in stale_clients:
+                    client.delete()
+            except Exception:
+                # NOTE: make sure the loop doesn't crash
+                log.exception('Error while pruning clients')
+            await asyncio.sleep(10)

+ 0 - 2
nicegui/globals.py

@@ -9,7 +9,6 @@ from socketio import AsyncServer
 from uvicorn import Server
 
 if TYPE_CHECKING:
-    from .air import Air
     from .app import App
     from .client import Client
     from .language import Language
@@ -33,7 +32,6 @@ reconnect_timeout: float
 tailwind: bool
 prod_js: bool
 endpoint_documentation: Literal['none', 'internal', 'page', 'all'] = 'none'
-air: Optional[Air] = None
 storage_path: Path = Path(os.environ.get('NICEGUI_STORAGE_PATH', '.nicegui')).resolve()
 socket_io_js_query_params: Dict = {}
 socket_io_js_extra_headers: Dict = {}

+ 10 - 79
nicegui/nicegui.py

@@ -1,6 +1,5 @@
 import asyncio
 import mimetypes
-import time
 import urllib.parse
 from pathlib import Path
 from typing import Dict
@@ -11,13 +10,13 @@ from fastapi.responses import FileResponse, Response
 from fastapi.staticfiles import StaticFiles
 from fastapi_socketio import SocketManager
 
-from . import (background_tasks, binding, favicon, globals, json, outbox,  # pylint: disable=redefined-builtin
+from . import (air, background_tasks, binding, favicon, globals, json, outbox,  # pylint: disable=redefined-builtin
                run_executor, welcome)
 from .app import App
 from .client import Client
 from .dependencies import js_components, libraries
 from .error import error_content
-from .helpers import is_file, safe_invoke
+from .helpers import is_file
 from .json import NiceGUIJSONResponse
 from .logging import log
 from .middlewares import RedirectWithPrefixMiddleware
@@ -94,13 +93,12 @@ def handle_startup(with_welcome_message: bool = True) -> None:
     globals.loop = asyncio.get_running_loop()
     globals.app.start()
     background_tasks.create(binding.refresh_loop(), name='refresh bindings')
-    background_tasks.create(outbox.loop(Client.instances), name='send outbox')
-    background_tasks.create(prune_clients(), name='prune clients')
+    background_tasks.create(outbox.loop(air.instance), name='send outbox')
+    background_tasks.create(Client.prune_instances(), name='prune clients')
     background_tasks.create(prune_slot_stacks(), name='prune slot stacks')
     if with_welcome_message:
         background_tasks.create(welcome.print_message())
-    if globals.air:
-        background_tasks.create(globals.air.connect())
+    air.connect()
 
 
 @app.on_event('shutdown')
@@ -110,8 +108,7 @@ async def handle_shutdown() -> None:
         app.native.main_window.signal_server_shutdown()
     globals.app.stop()
     run_executor.tear_down()
-    if globals.air:
-        await globals.air.disconnect()
+    air.disconnect()
 
 
 @app.exception_handler(404)
@@ -137,21 +134,10 @@ async def _on_handshake(sid: str, client_id: str) -> bool:
         return False
     client.environ = sio.get_environ(sid)
     await sio.enter_room(sid, client.id)
-    handle_handshake(client)
+    client.handle_handshake()
     return True
 
 
-def handle_handshake(client: Client) -> None:
-    """Cancel pending disconnect task and invoke connect handlers."""
-    if client.disconnect_task:
-        client.disconnect_task.cancel()
-        client.disconnect_task = None
-    for t in client.connect_handlers:
-        safe_invoke(t, client)
-    for t in app._connect_handlers:  # pylint: disable=protected-access
-        safe_invoke(t, client)
-
-
 @sio.on('disconnect')
 def _on_disconnect(sid: str) -> None:
     query_bytes: bytearray = sio.get_environ(sid)['asgi.scope']['query_string']
@@ -159,19 +145,7 @@ def _on_disconnect(sid: str) -> None:
     client_id = query['client_id'][0]
     client = Client.instances.get(client_id)
     if client:
-        client.disconnect_task = background_tasks.create(handle_disconnect(client))
-
-
-async def handle_disconnect(client: Client) -> None:
-    """Wait for the browser to reconnect; invoke disconnect handlers if it doesn't."""
-    delay = client.page.reconnect_timeout if client.page.reconnect_timeout is not None else globals.reconnect_timeout
-    await asyncio.sleep(delay)
-    if not client.shared:
-        _delete_client(client)
-    for t in client.disconnect_handlers:
-        safe_invoke(t, client)
-    for t in app._disconnect_handlers:  # pylint: disable=protected-access
-        safe_invoke(t, client)
+        client.handle_disconnect()
 
 
 @sio.on('event')
@@ -179,18 +153,7 @@ def _on_event(_: str, msg: Dict) -> None:
     client = Client.instances.get(msg['client_id'])
     if not client or not client.has_socket_connection:
         return
-    handle_event(client, msg)
-
-
-def handle_event(client: Client, msg: Dict) -> None:
-    """Forward an event to the corresponding element."""
-    with client:
-        sender = client.elements.get(msg['id'])
-        if sender:
-            msg['args'] = [None if arg is None else json.loads(arg) for arg in msg.get('args', [])]
-            if len(msg['args']) == 1:
-                msg['args'] = msg['args'][0]
-            sender._handle_event(msg)  # pylint: disable=protected-access
+    client.handle_event(msg)
 
 
 @sio.on('javascript_response')
@@ -198,29 +161,7 @@ def _on_javascript_response(_: str, msg: Dict) -> None:
     client = Client.instances.get(msg['client_id'])
     if not client:
         return
-    handle_javascript_response(client, msg)
-
-
-def handle_javascript_response(client: Client, msg: Dict) -> None:
-    """Forward a JavaScript response to the corresponding element."""
-    client.waiting_javascript_commands[msg['request_id']] = msg['result']
-
-
-async def prune_clients() -> None:
-    """Prune stale clients in an endless loop."""
-    while True:
-        try:
-            stale_clients = [
-                client
-                for client in Client.instances.values()
-                if not client.shared and not client.has_socket_connection and client.created < time.time() - 60.0
-            ]
-            for client in stale_clients:
-                _delete_client(client)
-        except Exception:
-            # NOTE: make sure the loop doesn't crash
-            log.exception('Error while pruning clients')
-        await asyncio.sleep(10)
+    client.handle_javascript_response(msg)
 
 
 async def prune_slot_stacks() -> None:
@@ -243,13 +184,3 @@ async def prune_slot_stacks() -> None:
             # NOTE: make sure the loop doesn't crash
             log.exception('Error while pruning slot stacks')
         await asyncio.sleep(10)
-
-
-def _delete_client(client: Client) -> None:
-    """Delete a client and all its elements.
-
-    If the global clients dictionary does not contain the client, its elements are still removed and a KeyError is raised.
-    Normally this should never happen, but has been observed (see #1826).
-    """
-    client.remove_all_elements()
-    del Client.instances[client.id]

+ 4 - 10
nicegui/outbox.py

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, DefaultDict, Deque, Dict, Optional, Tuple
 from . import globals  # pylint: disable=redefined-builtin
 
 if TYPE_CHECKING:
-    from .client import Client
+    from .air import Air
     from .element import Element
 
 ClientId = str
@@ -34,18 +34,12 @@ def enqueue_message(message_type: MessageType, data: Any, target_id: ClientId) -
     message_queue.append((target_id, message_type, data))
 
 
-async def loop(clients: Dict[str, Client]) -> None:
+async def loop(air: Optional[Air]) -> None:
     """Emit queued updates and messages in an endless loop."""
-    def is_target_on_air(target_id: str) -> bool:
-        if target_id in clients:
-            return clients[target_id].on_air
-        return target_id in globals.sio.manager.rooms
-
     async def emit(message_type: MessageType, data: Any, target_id: ClientId) -> None:
         await globals.sio.emit(message_type, data, room=target_id)
-        if is_target_on_air(target_id):
-            assert globals.air is not None
-            await globals.air.emit(message_type, data, room=target_id)
+        if air is not None and air.is_air_target(target_id):
+            await air.emit(message_type, data, room=target_id)
 
     while True:
         if not update_queue and not message_queue:

+ 2 - 3
nicegui/run.py

@@ -13,9 +13,8 @@ from uvicorn.supervisors import ChangeReload, Multiprocess
 
 from . import native_mode  # pylint: disable=redefined-builtin
 from . import storage  # pylint: disable=redefined-builtin
-from . import globals, helpers  # pylint: disable=redefined-builtin
+from . import air, globals, helpers  # pylint: disable=redefined-builtin
 from . import native as native_module
-from .air import Air
 from .client import Client
 from .language import Language
 from .logging import log
@@ -122,7 +121,7 @@ def run(*,
             route.include_in_schema = endpoint_documentation in {'page', 'all'}
 
     if on_air:
-        globals.air = Air('' if on_air is True else on_air)
+        air.instance = air.Air('' if on_air is True else on_air)
 
     if multiprocessing.current_process().name != 'MainProcess':
         return