Bläddra i källkod

ui.on_connect, ui.on_disconnect fire for all clients

Rodja Trappe 2 år sedan
förälder
incheckning
f4cf582343
6 ändrade filer med 58 tillägg och 20 borttagningar
  1. 2 2
      nicegui/functions/lifecycle.py
  2. 2 0
      nicegui/globals.py
  3. 3 3
      nicegui/helpers.py
  4. 4 0
      nicegui/nicegui.py
  5. 37 6
      tests/test_lifecycle.py
  6. 10 9
      website/reference.py

+ 2 - 2
nicegui/functions/lifecycle.py

@@ -4,11 +4,11 @@ from .. import globals
 
 
 def on_connect(handler: Union[Callable, Awaitable]) -> None:
-    globals.get_client().connect_handlers.append(handler)
+    globals.connect_handlers.append(handler)
 
 
 def on_disconnect(handler: Union[Callable, Awaitable]) -> None:
-    globals.get_client().disconnect_handlers.append(handler)
+    globals.disconnect_handlers.append(handler)
 
 
 def on_startup(handler: Union[Callable, Awaitable]) -> None:

+ 2 - 0
nicegui/globals.py

@@ -48,6 +48,8 @@ tasks: List[asyncio.tasks.Task] = []
 
 startup_handlers: List[Union[Callable, Awaitable]] = []
 shutdown_handlers: List[Union[Callable, Awaitable]] = []
+connect_handlers: List[Union[Callable, Awaitable]] = []
+disconnect_handlers: List[Union[Callable, Awaitable]] = []
 
 
 def get_task_id() -> int:

+ 3 - 3
nicegui/helpers.py

@@ -1,7 +1,7 @@
 import asyncio
 import functools
 from contextlib import nullcontext
-from typing import Any, Awaitable, Callable, Optional, Union
+from typing import Any, Awaitable, Callable, List, Optional, Union
 
 from . import globals
 from .client import Client
@@ -14,7 +14,7 @@ def is_coroutine(object: Any) -> bool:
     return asyncio.iscoroutinefunction(object)
 
 
-def safe_invoke(func: Union[Callable, Awaitable], client: Optional[Client] = None) -> None:
+def safe_invoke(func: Union[Callable, Awaitable], client: Optional[Client] = None, *args: List[Any]) -> None:
     try:
         if isinstance(func, Awaitable):
             async def func_with_client():
@@ -23,7 +23,7 @@ def safe_invoke(func: Union[Callable, Awaitable], client: Optional[Client] = Non
             create_task(func_with_client())
         else:
             with client or nullcontext():
-                result = func()
+                result = func(*args)
             if isinstance(result, Awaitable):
                 async def result_with_client():
                     with client or nullcontext():

+ 4 - 0
nicegui/nicegui.py

@@ -94,6 +94,8 @@ async def handle_handshake(sid: str) -> bool:
     sio.enter_room(sid, client.id)
     for t in client.connect_handlers:
         safe_invoke(t, client)
+    for t in globals.connect_handlers:
+        safe_invoke(t, client, client)
     return True
 
 
@@ -106,6 +108,8 @@ async def handle_disconnect(sid: str) -> None:
         delete_client(client.id)
     for t in client.disconnect_handlers:
         safe_invoke(t, client)
+    for t in globals.disconnect_handlers:
+        safe_invoke(t, client, client)
 
 
 @sio.on('event')

+ 37 - 6
tests/test_lifecycle.py

@@ -1,21 +1,52 @@
-from nicegui import ui
+from typing import List
+
+from nicegui import Client, ui
 
 from .screen import Screen
 
 
-def test_adding_elements_during_onconnect(screen: Screen):
-    ui.label('Label 1')
-    ui.on_connect(lambda: ui.label('Label 2'))
+def test_adding_elements_during_onconnect_on_auto_index_page(screen: Screen):
+    connections = []
+    ui.label('Adding labels on_connect')
+    ui.on_connect(lambda _: connections.append(ui.label(f'new connection {len(connections)}')))
 
     screen.open('/')
-    screen.should_contain('Label 2')
+    screen.should_contain('new connection 0')
+    screen.open('/')
+    screen.should_contain('new connection 0')
+    screen.should_contain('new connection 1')
+    screen.open('/')
+    screen.should_contain('new connection 0')
+    screen.should_contain('new connection 1')
+    screen.should_contain('new connection 2')
 
 
 def test_async_connect_handler(screen: Screen):
-    async def run_js():
+    async def run_js(client: Client):
         result.text = await ui.run_javascript('41 + 1')
     result = ui.label()
     ui.on_connect(run_js)
 
     screen.open('/')
     screen.should_contain('42')
+
+
+def test_connect_disconnect_is_called_for_each_client(screen: Screen):
+    events: List[str] = []
+
+    @ui.page('/')
+    def page(client: Client):
+        ui.label(f'client id: {client.id}')
+    ui.on_connect(lambda c: events.append(f'|connect {c.id}|'))
+    ui.on_disconnect(lambda c: events.append(f'|disconnect {c.id}|'))
+
+    screen.open('/')
+    screen.open('/')
+    screen.open('/')
+    screen.wait(0.1)
+    assert len(events) == 5
+    assert events[0].startswith('|connect ')
+    assert events[1].startswith('|disconnect ')
+    assert events[2].startswith('|connect ')
+    assert events[3].startswith('|disconnect ')
+    assert events[4].startswith('|connect ')

+ 10 - 9
website/reference.py

@@ -486,22 +486,23 @@ You can run a function or coroutine as a parallel task by passing it to one of t
 
 - `ui.on_startup`: Called when NiceGUI is started or restarted.
 - `ui.on_shutdown`: Called when NiceGUI is shut down or restarted.
-- `ui.on_connect`: Called when a client connects to NiceGUI. (Optional argument: Starlette request)
-- `ui.on_disconnect`: Called when a client disconnects from NiceGUI. (Optional argument: socket)
+- `ui.on_connect`: Called for each client which connects. (nicegui.Client is passed as argument)
+- `ui.on_disconnect`: Called for each client which disconnects. (nicegui.Client is passed as argument)
 
 When NiceGUI is shut down or restarted, the startup tasks will be automatically canceled.
 ''', immediate=True)
     def lifecycle_example():
-        import asyncio
+        from nicegui import Client
 
-        l = ui.label()
+        async def increment(client: Client):
+            counter.value += 1
 
-        async def countdown():
-            for i in [5, 4, 3, 2, 1, 0]:
-                l.text = f'{i}...' if i else 'Take-off!'
-                await asyncio.sleep(1)
+        async def decrement(client: Client):
+            counter.value -= 1
 
-        ui.on_connect(countdown)
+        counter = ui.number('connections', value=0).props('readonly').classes('w-24')
+        ui.on_connect(increment)
+        ui.on_disconnect(decrement)
 
     @example(ui.timer)
     def timer_example():