|
@@ -12,6 +12,7 @@ import inspect
|
|
|
import io
|
|
|
import json
|
|
|
import sys
|
|
|
+import time
|
|
|
import traceback
|
|
|
from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
|
|
|
from datetime import datetime
|
|
@@ -1520,12 +1521,32 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
yield state
|
|
|
delta = await state._get_resolved_delta()
|
|
|
if delta:
|
|
|
- # When the state is modified reset dirty status and emit the delta to the frontend.
|
|
|
- state._clean()
|
|
|
- await self.event_namespace.emit_update(
|
|
|
- update=StateUpdate(delta=delta),
|
|
|
- sid=state.router.session.session_id,
|
|
|
+ synthetic_modify_event = Event(
|
|
|
+ token=state.router.session.client_token,
|
|
|
+ name="App.modify_state",
|
|
|
)
|
|
|
+ if not (
|
|
|
+ updated_sid := await self.event_namespace._get_sid_with_timeout(
|
|
|
+ event=synthetic_modify_event, warn_on_error=True
|
|
|
+ )
|
|
|
+ ):
|
|
|
+ # The client has disconnected, so we don't need to emit the update.
|
|
|
+ return
|
|
|
+ update = StateUpdate(delta=delta)
|
|
|
+ await self.event_namespace.emit_update(update=update, sid=updated_sid)
|
|
|
+ if not self.event_namespace.token_to_sid.get(synthetic_modify_event.token) and (
|
|
|
+ updated_sid := await self.event_namespace._get_sid_with_timeout(
|
|
|
+ event=synthetic_modify_event, warn_on_error=True
|
|
|
+ )
|
|
|
+ ):
|
|
|
+ #breakpoint()
|
|
|
+ # The client reconnected after failing to send, so try to send it again.
|
|
|
+ await self.event_namespace.emit_update(
|
|
|
+ update=update, sid=updated_sid
|
|
|
+ )
|
|
|
+ if updated_sid:
|
|
|
+ # Reset dirty status after the delta is sent to the frontend.
|
|
|
+ state._clean()
|
|
|
|
|
|
def _process_background(
|
|
|
self, state: BaseState, event: Event
|
|
@@ -1544,6 +1565,27 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
if not handler.is_background:
|
|
|
return None
|
|
|
|
|
|
+ started_at = datetime.now()
|
|
|
+
|
|
|
+ async def _get_sid_or_bail(
|
|
|
+ timeout: int | float = environment.REFLEX_CLIENT_DISCONNECT_TIMEOUT.get(),
|
|
|
+ ) -> str | None:
|
|
|
+ if self.event_namespace is None:
|
|
|
+ raise RuntimeError("App has not been initialized yet.")
|
|
|
+ if not (
|
|
|
+ sid := await self.event_namespace._get_sid_with_timeout(
|
|
|
+ event, timeout, warn_on_error=False
|
|
|
+ )
|
|
|
+ ):
|
|
|
+ # If the sid is not found, skip sending the update and cancel task.
|
|
|
+ console.warn(
|
|
|
+ f"Cannot send update from background task {handler.fn} (started at {started_at.isoformat()}): "
|
|
|
+ f"No websocket associated with token {event.token} after {timeout}s.",
|
|
|
+ dedupe=True,
|
|
|
+ )
|
|
|
+ task.cancel()
|
|
|
+ return sid
|
|
|
+
|
|
|
async def _coro():
|
|
|
"""Coroutine to process the event and emit updates inside an asyncio.Task.
|
|
|
|
|
@@ -1560,16 +1602,37 @@ class App(MiddlewareMixin, LifespanMixin):
|
|
|
# Postprocess the event.
|
|
|
update = await self._postprocess(state, event, update)
|
|
|
|
|
|
+ if not (sid := await _get_sid_or_bail()):
|
|
|
+ return
|
|
|
+
|
|
|
# Send the update to the client.
|
|
|
- await self.event_namespace.emit_update(
|
|
|
- update=update,
|
|
|
- sid=state.router.session.session_id,
|
|
|
- )
|
|
|
+ await self.event_namespace.emit_update(update=update, sid=sid)
|
|
|
+
|
|
|
+ # Check if the client is still connected and attempt to resend if it wasn't
|
|
|
+ if not self.event_namespace.token_to_sid.get(event.token) and (
|
|
|
+ sid := await _get_sid_or_bail()
|
|
|
+ ):
|
|
|
+ #breakpoint()
|
|
|
+ print(f"Last attempt failed, trying again {update}")
|
|
|
+ # Second try to send the update.
|
|
|
+ await self.event_namespace.emit_update(update=update, sid=sid)
|
|
|
+
|
|
|
+ def finish(t: asyncio.Task[None]):
|
|
|
+ """Finish the task and remove it from the background tasks.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ t: The task to finish.
|
|
|
+ """
|
|
|
+ # Remove the task from the background tasks.
|
|
|
+ self._background_tasks.discard(t)
|
|
|
+
|
|
|
+ # Raise for errors.
|
|
|
+ t.result()
|
|
|
|
|
|
task = asyncio.create_task(_coro())
|
|
|
self._background_tasks.add(task)
|
|
|
# Clean up task from background_tasks set when complete.
|
|
|
- task.add_done_callback(self._background_tasks.discard)
|
|
|
+ task.add_done_callback(finish)
|
|
|
return task
|
|
|
|
|
|
def _validate_exception_handlers(self):
|
|
@@ -1975,6 +2038,7 @@ class EventNamespace(AsyncNamespace):
|
|
|
"""
|
|
|
disconnect_token = self.sid_to_token.pop(sid, None)
|
|
|
if disconnect_token:
|
|
|
+ print(f"Disconnect {disconnect_token}")
|
|
|
self.token_to_sid.pop(disconnect_token, None)
|
|
|
|
|
|
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
|
@@ -1985,10 +2049,32 @@ class EventNamespace(AsyncNamespace):
|
|
|
sid: The Socket.IO session id.
|
|
|
"""
|
|
|
# Creating a task prevents the update from being blocked behind other coroutines.
|
|
|
+ print(f"emit_update: {update=} {sid=}")
|
|
|
await asyncio.create_task(
|
|
|
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
|
|
|
)
|
|
|
|
|
|
+ async def _get_sid_with_timeout(
|
|
|
+ self,
|
|
|
+ event: Event,
|
|
|
+ timeout: int | float | None = None,
|
|
|
+ warn_on_error: bool = False,
|
|
|
+ ) -> str | None:
|
|
|
+ """Wait up to timeout seconds for the sid to be available, otherwise cancel the task."""
|
|
|
+ if timeout is None:
|
|
|
+ timeout = environment.REFLEX_CLIENT_DISCONNECT_TIMEOUT.get()
|
|
|
+ deadline = time.time() + timeout
|
|
|
+
|
|
|
+ # Find the latest sid for the client_token.
|
|
|
+ while not (sid := self.token_to_sid.get(event.token)) and time.time() < deadline:
|
|
|
+ await asyncio.sleep(0.5)
|
|
|
+ if not sid and warn_on_error:
|
|
|
+ console.warn(
|
|
|
+ f"Cannot send update from event {event.name}: "
|
|
|
+ f"No websocket associated with token {event.token} after {timeout}s."
|
|
|
+ )
|
|
|
+ return sid
|
|
|
+
|
|
|
async def on_event(self, sid: str, data: Any):
|
|
|
"""Event for receiving front-end websocket events.
|
|
|
|
|
@@ -2053,8 +2139,23 @@ class EventNamespace(AsyncNamespace):
|
|
|
|
|
|
# Process the events.
|
|
|
async for update in process(self.app, event, sid, headers, client_ip):
|
|
|
+ if not (
|
|
|
+ updated_sid := await self._get_sid_with_timeout(
|
|
|
+ event, warn_on_error=True
|
|
|
+ )
|
|
|
+ ):
|
|
|
+ # The client has disconnected, so we don't need to emit the update.
|
|
|
+ return
|
|
|
# Emit the update from processing the event.
|
|
|
- await self.emit_update(update=update, sid=sid)
|
|
|
+ await self.emit_update(update=update, sid=updated_sid)
|
|
|
+ if not self.token_to_sid.get(event.token) and (
|
|
|
+ updated_sid := await self._get_sid_with_timeout(
|
|
|
+ event=event, warn_on_error=True
|
|
|
+ )
|
|
|
+ ):
|
|
|
+ #breakpoint()
|
|
|
+ # The client reconnected after failing to send, so try to send it again.
|
|
|
+ await self.emit_update(update=update, sid=updated_sid)
|
|
|
|
|
|
async def on_ping(self, sid: str):
|
|
|
"""Event for testing the API endpoint.
|