Selaa lähdekoodia

replace client stack with slot stack

Falko Schindler 2 vuotta sitten
vanhempi
säilyke
c326e1c808

+ 4 - 6
nicegui/async_updater.py

@@ -1,14 +1,12 @@
-from typing import TYPE_CHECKING, Any, Coroutine, Generator
+from typing import Any, Coroutine, Generator
 
-if TYPE_CHECKING:
-    from .client import Client
+from . import globals
 
 
 class AsyncUpdater:
 
-    def __init__(self, coro: Coroutine, client: 'Client') -> None:
+    def __init__(self, coro: Coroutine) -> None:
         self.coro = coro
-        self.client = client
 
     def __await__(self) -> Generator[Any, None, Any]:
         coro_iter = self.coro.__await__()
@@ -28,5 +26,5 @@ class AsyncUpdater:
                 send, message = iter_throw, err
 
     def lazy_update(self) -> None:
-        for slot in self.client.slot_stack:
+        for slot in globals.get_slot_stack():
             slot.lazy_update()

+ 11 - 14
nicegui/client.py

@@ -3,15 +3,13 @@ import json
 import time
 import uuid
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
 from fastapi.responses import HTMLResponse
 
 from . import globals, ui, vue
-from .async_updater import AsyncUpdater
 from .element import Element
 from .favicon import get_favicon_url
-from .slot import Slot
 from .task_logger import create_task
 
 if TYPE_CHECKING:
@@ -22,22 +20,20 @@ TEMPLATE = (Path(__file__).parent / 'templates' / 'index.html').read_text()
 
 class Client:
 
-    def __init__(self, page: 'page') -> None:
+    def __init__(self, page: 'page', *, shared: bool = False) -> None:
         self.id = globals.next_client_id
         globals.next_client_id += 1
         globals.clients[self.id] = self
 
         self.elements: Dict[str, Element] = {}
         self.next_element_id: int = 0
-        self.slot_stack: List[Slot] = []
         self.is_waiting_for_handshake: bool = False
         self.environ: Optional[Dict[str, Any]] = None
+        self.shared = shared
 
-        globals.get_client_stack().append(self)
-        with Element('q-layout').props('view="HHH LpR FFF"') as self.layout:
+        with Element('q-layout', _client=self).props('view="HHH LpR FFF"') as self.layout:
             with Element('q-page-container'):
                 self.content = Element('div').classes('q-pa-md column items-start gap-4')
-        globals.get_client_stack().pop()
 
         self.waiting_javascript_commands: Dict[str, str] = {}
 
@@ -51,16 +47,11 @@ class Client:
         return self.environ.get('REMOTE_ADDR') if self.environ else None
 
     def __enter__(self):
-        globals.get_client_stack().append(self)
         self.content.__enter__()
         return self
 
     def __exit__(self, *_):
         self.content.__exit__()
-        globals.get_client_stack().pop()
-
-    def watch_asyncs(self, coro: Coroutine) -> AsyncUpdater:
-        return AsyncUpdater(coro, self)
 
     def build_response(self) -> HTMLResponse:
         vue_html, vue_styles, vue_scripts = vue.generate_vue_content()
@@ -110,10 +101,16 @@ class Client:
         create_task(globals.sio.emit('open', path, room=str(self.id)))
 
 
+class IndexClient(Client):
+
+    def __init__(self, page: 'page') -> None:
+        super().__init__(page, shared=True)
+
+
 class ErrorClient(Client):
 
     def __init__(self, page: 'page') -> None:
-        super().__init__(page)
+        super().__init__(page, shared=True)
         with self:
             with ui.column().classes('w-full py-20 items-center gap-0'):
                 ui.icon('☹').classes('text-8xl py-5') \

+ 9 - 5
nicegui/element.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import shlex
 from abc import ABC
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 from . import binding, globals
 from .elements.mixins.visibility import Visibility
@@ -11,12 +11,15 @@ from .event_listener import EventListener
 from .slot import Slot
 from .task_logger import create_task
 
+if TYPE_CHECKING:
+    from .client import Client
+
 
 class Element(ABC, Visibility):
 
-    def __init__(self, tag: str) -> None:
+    def __init__(self, tag: str, *, _client: Optional[Client] = None) -> None:
         super().__init__()
-        self.client = globals.get_client()
+        self.client = _client or globals.get_client()
         self.id = self.client.next_element_id
         self.client.next_element_id += 1
         self.tag = tag
@@ -30,8 +33,9 @@ class Element(ABC, Visibility):
 
         self.client.elements[self.id] = self
         self.parent_slot: Optional[Slot] = None
-        if self.client.slot_stack:
-            self.parent_slot = self.client.slot_stack[-1]
+        slot_stack = globals.get_slot_stack()
+        if slot_stack:
+            self.parent_slot = slot_stack[-1]
             self.parent_slot.children.append(self)
 
     def add_slot(self, name: str) -> Slot:

+ 2 - 1
nicegui/elements/menu.py

@@ -1,5 +1,6 @@
 from typing import Callable, Optional
 
+from .. import globals
 from ..events import ClickEventArguments, handle_event
 from .mixins.text_element import TextElement
 from .mixins.value_element import ValueElement
@@ -36,7 +37,7 @@ class MenuItem(TextElement):
         :param auto_close: whether the menu should be closed after a click event (default: `True`)
         """
         super().__init__(tag='q-item', text=text)
-        self.menu: Menu = self.client.slot_stack[-1].parent
+        self.menu: Menu = globals.get_slot().parent
         self._props['clickable'] = True
 
         def handle_click(_) -> None:

+ 1 - 1
nicegui/elements/scene_object3d.py

@@ -15,7 +15,7 @@ class Object3D:
         self.type = type
         self.id = str(uuid.uuid4())
         self.name: Optional[str] = None
-        self.scene: 'Scene' = globals.get_client().slot_stack[-1].parent
+        self.scene: 'Scene' = globals.get_slot().parent
         self.scene.objects[self.id] = self
         self.parent: Object3D = self.scene.stack[-1]
         self.args: List = list(args)

+ 9 - 7
nicegui/events.py

@@ -4,6 +4,7 @@ from inspect import signature
 from typing import TYPE_CHECKING, Any, Callable, List, Optional
 
 from . import globals
+from .async_updater import AsyncUpdater
 from .client import Client
 from .helpers import is_coroutine
 from .lifecycle import on_startup
@@ -250,12 +251,13 @@ def handle_event(handler: Optional[Callable], arguments: EventArguments) -> None
         no_arguments = not signature(handler).parameters
         with arguments.sender.parent_slot:
             result = handler() if no_arguments else handler(arguments)
-            if is_coroutine(handler):
-                async def wait_for_result():
-                    await arguments.sender.client.watch_asyncs(result)
-                if globals.loop and globals.loop.is_running():
-                    create_task(wait_for_result(), name=str(handler))
-                else:
-                    on_startup(None, wait_for_result())
+        if is_coroutine(handler):
+            async def wait_for_result():
+                with arguments.sender.parent_slot:
+                    await AsyncUpdater(result)
+            if globals.loop and globals.loop.is_running():
+                create_task(wait_for_result(), name=str(handler))
+            else:
+                on_startup(None, wait_for_result())
     except Exception:
         traceback.print_exc()

+ 9 - 8
nicegui/functions/timer.py

@@ -4,8 +4,8 @@ import traceback
 from typing import Callable
 
 from .. import globals
+from ..async_updater import AsyncUpdater
 from ..binding import BindableProperty
-from ..client import Client
 from ..helpers import is_coroutine
 from ..lifecycle import on_startup
 from ..task_logger import create_task
@@ -30,6 +30,7 @@ class Timer:
         self.interval = interval
         self.callback = callback
         self.active = active
+        self.slot = globals.get_slot()
 
         coroutine = self._run_once if once else self._run_in_loop
         if globals.state == globals.State.STARTED:
@@ -38,19 +39,19 @@ class Timer:
             on_startup(coroutine)
 
     async def _run_once(self) -> None:
-        with globals.get_client() as client, client.slot_stack[-1]:
+        with self.slot:
             await asyncio.sleep(self.interval)
-            await self._invoke_callback(client)
+            await self._invoke_callback()
 
     async def _run_in_loop(self) -> None:
-        with globals.get_client() as client, client.slot_stack[-1]:
+        with self.slot:
             while True:
-                if client.id not in globals.clients:
+                if self.slot.parent.client.id not in globals.clients:
                     return
                 try:
                     start = time.time()
                     if self.active:
-                        await self._invoke_callback(client)
+                        await self._invoke_callback()
                     dt = time.time() - start
                     await asyncio.sleep(self.interval - dt)
                 except asyncio.CancelledError:
@@ -59,10 +60,10 @@ class Timer:
                     traceback.print_exc()
                     await asyncio.sleep(self.interval)
 
-    async def _invoke_callback(self, client: Client) -> None:
+    async def _invoke_callback(self) -> None:
         try:
             result = self.callback()
             if is_coroutine(self.callback):
-                await client.watch_asyncs(result)
+                await AsyncUpdater(result)
         except Exception:
             traceback.print_exc()

+ 15 - 7
nicegui/globals.py

@@ -9,6 +9,7 @@ from uvicorn import Server
 
 if TYPE_CHECKING:
     from .client import Client
+    from .slot import Slot
 
 
 class State(Enum):
@@ -34,7 +35,7 @@ dark: Optional[bool]
 binding_refresh_interval: float
 excludes: List[str]
 
-client_stacks: Dict[int, List['Client']] = {}
+slot_stacks: Dict[int, List['Slot']] = {}
 clients: Dict[int, 'Client'] = {}
 next_client_id: int = 0
 index_client: 'Client' = ...
@@ -51,15 +52,22 @@ shutdown_handlers: List[Union[Callable, Awaitable]] = []
 
 
 def get_task_id() -> int:
-    return id(asyncio.current_task()) if loop and loop.is_running() else 0
+    try:
+        return id(asyncio.current_task())
+    except RuntimeError:
+        return 0
 
 
-def get_client_stack() -> List['Client']:
+def get_slot_stack() -> List['Slot']:
     task_id = get_task_id()
-    if task_id not in client_stacks:
-        client_stacks[task_id] = [index_client]
-    return client_stacks[task_id]
+    if task_id not in slot_stacks:
+        slot_stacks[task_id] = []
+    return slot_stacks[task_id]
+
+
+def get_slot() -> 'Slot':
+    return get_slot_stack()[-1]
 
 
 def get_client() -> 'Client':
-    return get_client_stack()[-1]
+    return get_slot().parent.client

+ 3 - 3
nicegui/nicegui.py

@@ -10,7 +10,7 @@ from fastapi.staticfiles import StaticFiles
 from fastapi_socketio import SocketManager
 
 from . import binding, globals, vue
-from .client import Client, ErrorClient
+from .client import Client, ErrorClient, IndexClient
 from .favicon import create_favicon_routes
 from .helpers import safe_invoke
 from .page import page
@@ -22,8 +22,8 @@ globals.sio = sio = SocketManager(app=app)._sio
 app.add_middleware(GZipMiddleware)
 app.mount("/static", StaticFiles(directory=Path(__file__).parent / 'static'), name='static')
 
-globals.index_client = Client(page('/')).__enter__()
 globals.error_client = ErrorClient(page(''))
+globals.index_client = IndexClient(page('/')).__enter__()
 
 
 @app.get('/')
@@ -80,7 +80,7 @@ async def handle_connect(sid: str, _) -> None:
 @sio.on('disconnect')
 async def handle_disconnect(sid: str) -> None:
     client = get_client(sid)
-    if client.id != 0:
+    if not client.shared:
         del globals.clients[client.id]
 
 

+ 4 - 2
nicegui/slot.py

@@ -1,5 +1,7 @@
 from typing import TYPE_CHECKING, List
 
+from . import globals
+
 if TYPE_CHECKING:
     from .element import Element
 
@@ -14,11 +16,11 @@ class Slot:
 
     def __enter__(self):
         self.child_count = len(self.children)
-        self.parent.client.slot_stack.append(self)
+        globals.get_slot_stack().append(self)
         return self
 
     def __exit__(self, *_):
-        self.parent.client.slot_stack.pop()
+        globals.get_slot_stack().pop()
         self.lazy_update()
 
     def lazy_update(self) -> None: