1
0
Masen Furer 2 долоо хоног өмнө
parent
commit
d8075bbcc6

+ 112 - 11
reflex/app.py

@@ -12,6 +12,7 @@ import inspect
 import io
 import json
 import sys
+import time
 import traceback
 from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
 from datetime import datetime
@@ -1520,12 +1521,32 @@ class App(MiddlewareMixin, LifespanMixin):
             yield state
             delta = await state._get_resolved_delta()
             if delta:
-                # When the state is modified reset dirty status and emit the delta to the frontend.
-                state._clean()
-                await self.event_namespace.emit_update(
-                    update=StateUpdate(delta=delta),
-                    sid=state.router.session.session_id,
+                synthetic_modify_event = Event(
+                    token=state.router.session.client_token,
+                    name="App.modify_state",
                 )
+                if not (
+                    updated_sid := await self.event_namespace._get_sid_with_timeout(
+                        event=synthetic_modify_event, warn_on_error=True
+                    )
+                ):
+                    # The client has disconnected, so we don't need to emit the update.
+                    return
+                update = StateUpdate(delta=delta)
+                await self.event_namespace.emit_update(update=update, sid=updated_sid)
+                if not self.event_namespace.token_to_sid.get(synthetic_modify_event.token) and (
+                    updated_sid := await self.event_namespace._get_sid_with_timeout(
+                        event=synthetic_modify_event, warn_on_error=True
+                    )
+                ):
+                    #breakpoint()
+                    # The client reconnected after failing to send, so try to send it again.
+                    await self.event_namespace.emit_update(
+                        update=update, sid=updated_sid
+                    )
+                if updated_sid:
+                    # Reset dirty status after the delta is sent to the frontend.
+                    state._clean()
 
     def _process_background(
         self, state: BaseState, event: Event
@@ -1544,6 +1565,27 @@ class App(MiddlewareMixin, LifespanMixin):
         if not handler.is_background:
             return None
 
+        started_at = datetime.now()
+
+        async def _get_sid_or_bail(
+            timeout: int | float = environment.REFLEX_CLIENT_DISCONNECT_TIMEOUT.get(),
+        ) -> str | None:
+            if self.event_namespace is None:
+                raise RuntimeError("App has not been initialized yet.")
+            if not (
+                sid := await self.event_namespace._get_sid_with_timeout(
+                    event, timeout, warn_on_error=False
+                )
+            ):
+                # If the sid is not found, skip sending the update and cancel task.
+                console.warn(
+                    f"Cannot send update from background task {handler.fn} (started at {started_at.isoformat()}): "
+                    f"No websocket associated with token {event.token} after {timeout}s.",
+                    dedupe=True,
+                )
+                task.cancel()
+            return sid
+
         async def _coro():
             """Coroutine to process the event and emit updates inside an asyncio.Task.
 
@@ -1560,16 +1602,37 @@ class App(MiddlewareMixin, LifespanMixin):
                 # Postprocess the event.
                 update = await self._postprocess(state, event, update)
 
+                if not (sid := await _get_sid_or_bail()):
+                    return
+
                 # Send the update to the client.
-                await self.event_namespace.emit_update(
-                    update=update,
-                    sid=state.router.session.session_id,
-                )
+                await self.event_namespace.emit_update(update=update, sid=sid)
+
+                # Check if the client is still connected and attempt to resend if it wasn't
+                if not self.event_namespace.token_to_sid.get(event.token) and (
+                    sid := await _get_sid_or_bail()
+                ):
+                    #breakpoint()
+                    print(f"Last attempt failed, trying again {update}")
+                    # Second try to send the update.
+                    await self.event_namespace.emit_update(update=update, sid=sid)
+
+        def finish(t: asyncio.Task[None]):
+            """Finish the task and remove it from the background tasks.
+
+            Args:
+                t: The task to finish.
+            """
+            # Remove the task from the background tasks.
+            self._background_tasks.discard(t)
+
+            # Raise for errors.
+            t.result()
 
         task = asyncio.create_task(_coro())
         self._background_tasks.add(task)
         # Clean up task from background_tasks set when complete.
-        task.add_done_callback(self._background_tasks.discard)
+        task.add_done_callback(finish)
         return task
 
     def _validate_exception_handlers(self):
@@ -1975,6 +2038,7 @@ class EventNamespace(AsyncNamespace):
         """
         disconnect_token = self.sid_to_token.pop(sid, None)
         if disconnect_token:
+            print(f"Disconnect {disconnect_token}")
             self.token_to_sid.pop(disconnect_token, None)
 
     async def emit_update(self, update: StateUpdate, sid: str) -> None:
@@ -1985,10 +2049,32 @@ class EventNamespace(AsyncNamespace):
             sid: The Socket.IO session id.
         """
         # Creating a task prevents the update from being blocked behind other coroutines.
+        print(f"emit_update: {update=} {sid=}")
         await asyncio.create_task(
             self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
         )
 
+    async def _get_sid_with_timeout(
+        self,
+        event: Event,
+        timeout: int | float | None = None,
+        warn_on_error: bool = False,
+    ) -> str | None:
+        """Wait up to timeout seconds for the sid to be available, otherwise cancel the task."""
+        if timeout is None:
+            timeout = environment.REFLEX_CLIENT_DISCONNECT_TIMEOUT.get()
+        deadline = time.time() + timeout
+
+        # Find the latest sid for the client_token.
+        while not (sid := self.token_to_sid.get(event.token)) and time.time() < deadline:
+            await asyncio.sleep(0.5)
+        if not sid and warn_on_error:
+            console.warn(
+                f"Cannot send update from event {event.name}: "
+                f"No websocket associated with token {event.token} after {timeout}s."
+            )
+        return sid
+
     async def on_event(self, sid: str, data: Any):
         """Event for receiving front-end websocket events.
 
@@ -2053,8 +2139,23 @@ class EventNamespace(AsyncNamespace):
 
         # Process the events.
         async for update in process(self.app, event, sid, headers, client_ip):
+            if not (
+                updated_sid := await self._get_sid_with_timeout(
+                    event, warn_on_error=True
+                )
+            ):
+                # The client has disconnected, so we don't need to emit the update.
+                return
             # Emit the update from processing the event.
-            await self.emit_update(update=update, sid=sid)
+            await self.emit_update(update=update, sid=updated_sid)
+            if not self.token_to_sid.get(event.token) and (
+                updated_sid := await self._get_sid_with_timeout(
+                    event=event, warn_on_error=True
+                )
+            ):
+                #breakpoint()
+                # The client reconnected after failing to send, so try to send it again.
+                await self.emit_update(update=update, sid=updated_sid)
 
     async def on_ping(self, sid: str):
         """Event for testing the API endpoint.

+ 3 - 0
reflex/config.py

@@ -748,6 +748,9 @@ class EnvironmentVariables:
     # The timeout to wait for a pong from the websocket server in seconds.
     REFLEX_SOCKET_TIMEOUT: EnvVar[int] = env_var(constants.Ping.TIMEOUT)
 
+    # The maximum time to wait before dropping updates or cancelling a background task associated with a disconnected client.
+    REFLEX_CLIENT_DISCONNECT_TIMEOUT: EnvVar[int] = env_var(5)
+
 
 environment = EnvironmentVariables()
 

+ 43 - 0
tests/integration/test_background_task.py

@@ -405,3 +405,46 @@ def test_yield_in_async_with_self(
 
     yield_in_async_with_self_button.click()
     assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
+
+
+
+def test_background_task_refresh(
+    background_task: AppHarness,
+    driver: WebDriver,
+    token: str,
+):
+    """Test that background tasks keep working when the page is refreshed.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert background_task.app_instance is not None
+
+    racy_increment_button = driver.find_element(By.ID, "racy-increment")
+    driver.find_element(By.ID, "reset")
+
+    # get a reference to the iterations input
+    iterations_input = driver.find_element(By.ID, "iterations")
+
+    # kick off background tasks
+    iterations_input.clear()
+    iterations_input.send_keys("50")
+    racy_increment_button.click()
+
+    # Refresh a few times while the task is running.
+    driver.refresh()
+    driver.refresh()
+
+    # Get new references after page reloads.
+    counter = driver.find_element(By.ID, "counter")
+    counter_async_cv = driver.find_element(By.ID, "counter-async-cv")
+
+    # Make sure the final total is what we expect.
+    assert background_task._poll_for(lambda: counter.text == "200", timeout=40)
+    assert background_task._poll_for(lambda: counter_async_cv.text == "200", timeout=40)
+    # all tasks should have exited and cleaned up
+    assert background_task._poll_for(
+        lambda: not background_task.app_instance._background_tasks  # pyright: ignore [reportOptionalMemberAccess]
+    )