Prechádzať zdrojové kódy

introduced delayed updates

Falko Schindler 1 rok pred
rodič
commit
188c81b5b2
2 zmenil súbory, kde vykonal 72 pridanie a 27 odobranie
  1. 58 26
      nicegui/outbox.py
  2. 14 1
      tests/test_element.py

+ 58 - 26
nicegui/outbox.py

@@ -3,38 +3,56 @@ from __future__ import annotations
 import asyncio
 import time
 from collections import defaultdict, deque
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, DefaultDict, Deque, Dict, List, Optional, Tuple
 
 from . import core
+from .dataclasses import KWONLY_SLOTS
 
 if TYPE_CHECKING:
     from .air import Air
     from .client import Client
     from .element import Element
 
+
+@dataclass(**KWONLY_SLOTS)
+class DelayedUpdate:
+    time: float = 0
+    data: Dict[ElementId, Optional[Dict]] = field(default_factory=dict)
+
+
+@dataclass(**KWONLY_SLOTS)
+class DelayedMessage:
+    time: float
+    target_id: str
+    message_type: str
+    data: Any
+
+
 ClientId = str
 ElementId = int
 MessageType = str
 Message = Tuple[ClientId, MessageType, Any]
 
-update_queue: DefaultDict[ClientId, Dict[ElementId, Optional[Element]]] = defaultdict(dict)
-message_queue: Deque[Message] = deque()
-message_delay: List[Tuple[float, Message]] = []
+waiting_updates: DefaultDict[ClientId, Dict[ElementId, Optional[Element]]] = defaultdict(dict)
+delayed_updates: DefaultDict[ClientId, DelayedUpdate] = defaultdict(DelayedUpdate)
+waiting_messages: Deque[Message] = deque()
+delayed_messages: List[DelayedMessage] = []
 
 
 def enqueue_update(element: Element) -> None:
     """Enqueue an update for the given element."""
-    update_queue[element.client.id][element.id] = element
+    waiting_updates[element.client.id][element.id] = element
 
 
 def enqueue_delete(element: Element) -> None:
     """Enqueue a deletion for the given element."""
-    update_queue[element.client.id][element.id] = None
+    waiting_updates[element.client.id][element.id] = None
 
 
 def enqueue_message(message_type: MessageType, data: Any, target_id: ClientId) -> None:
     """Enqueue a message for the given client."""
-    message_queue.append((target_id, message_type, data))
+    waiting_messages.append((target_id, message_type, data))
 
 
 async def loop(air: Optional[Air], clients: Dict[str, Client]) -> None:
@@ -47,40 +65,54 @@ async def loop(air: Optional[Air], clients: Dict[str, Client]) -> None:
     while True:
         await asyncio.sleep(0.01)
 
-        if not update_queue and not message_queue and not message_delay:
+        if not delayed_updates and not waiting_updates and not delayed_messages and not waiting_messages:
             continue
 
         coros = []
         try:
-            # process update_queue
-            for client_id, elements in update_queue.items():
+            # process delayed_updates
+            for client_id in list(delayed_updates):
+                update = delayed_updates[client_id]
+                client = clients.get(client_id)
+                if client is None or client.has_socket_connection:
+                    coros.append(emit('update', update.data, client_id))
+                    delayed_updates.pop(client_id)
+                elif time.time() > update.time + 3.0:
+                    delayed_updates.pop(client_id)
+
+            # process waiting_updates
+            for client_id, elements in waiting_updates.items():
                 data = {
                     element_id: None if element is None else element._to_dict()  # pylint: disable=protected-access
                     for element_id, element in elements.items()
                 }
-                coros.append(emit('update', data, client_id))
-            update_queue.clear()
-
-            # process message_queue
-            for target_id, message_type, data in message_queue:
-                client = clients.get(target_id)
+                client = clients.get(client_id)
                 if client is None or client.has_socket_connection:
-                    coros.append(emit(message_type, data, target_id))
+                    coros.append(emit('update', data, client_id))
                 else:
-                    message_delay.append((time.time(), (target_id, message_type, data)))
-            message_queue.clear()
+                    delayed_updates[client_id].time = time.time()
+                    delayed_updates[client_id].data.update(data)
+            waiting_updates.clear()
 
-            # process message_delay
-            indices = []
-            for i, (t, (target_id, message_type, data)) in enumerate(message_delay):
+            # process delayed_messages
+            for i, message in enumerate(list(delayed_messages)):
+                client = clients.get(message.target_id)
+                if client is None or client.has_socket_connection:
+                    coros.append(emit(message.message_type, message.data, message.target_id))
+                    delayed_messages.pop(i)
+                elif time.time() > message.time + 3.0:
+                    delayed_messages.pop(i)
+
+            # process waiting_messages
+            for target_id, message_type, data in waiting_messages:
                 client = clients.get(target_id)
                 if client is None or client.has_socket_connection:
                     coros.append(emit(message_type, data, target_id))
-                    indices.append(i)
-                elif time.time() > t + 3.0:
-                    indices.append(i)
-            for i in reversed(indices):
-                message_delay.pop(i)
+                else:
+                    message = DelayedMessage(time=time.time(),
+                                             target_id=target_id, message_type=message_type, data=data)
+                    delayed_messages.append(message)
+            waiting_messages.clear()
 
             # run coroutines
             for coro in coros:

+ 14 - 1
tests/test_element.py

@@ -1,7 +1,7 @@
 import pytest
 from selenium.webdriver.common.by import By
 
-from nicegui import ui
+from nicegui import background_tasks, ui
 from nicegui.testing import Screen
 
 
@@ -281,3 +281,16 @@ def test_bad_characters(screen: Screen):
 
     screen.open('/')
     screen.should_contain(r'& <test> ` ${foo}')
+
+
+def test_update_before_client_connection(screen: Screen):
+    @ui.page('/')
+    def page():
+        label = ui.label('Hello world!')
+
+        async def update():
+            label.text = 'Hello again!'
+        background_tasks.create(update())
+
+    screen.open('/')
+    screen.should_contain('Hello again!')