Sfoglia il codice sorgente

Resolve async computed vars in background task (#5057)

When using `modify_state` (background task), the async computed vars were being
left as coroutine objects instead of awaiting them with _get_resolved_delta
Masen Furer 2 mesi fa
parent
commit
be63af5f5b
3 ha cambiato i file con 25 aggiunte e 3 eliminazioni
  1. 1 1
      reflex/app.py
  2. 9 1
      reflex/state.py
  3. 15 1
      tests/integration/test_background_task.py

+ 1 - 1
reflex/app.py

@@ -1427,7 +1427,7 @@ class App(MiddlewareMixin, LifespanMixin):
         async with self.state_manager.modify_state(token) as state:
             # No other event handler can modify the state while in this context.
             yield state
-            delta = state.get_delta()
+            delta = await state._get_resolved_delta()
             if delta:
                 # When the state is modified reset dirty status and emit the delta to the frontend.
                 state._clean()

+ 9 - 1
reflex/state.py

@@ -1702,7 +1702,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         try:
             # Get the delta after processing the event.
-            delta = await _resolve_delta(state.get_delta())
+            delta = await state._get_resolved_delta()
             state._clean()
 
             return StateUpdate(
@@ -1947,6 +1947,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Return the delta.
         return delta
 
+    async def _get_resolved_delta(self) -> Delta:
+        """Get the delta for the state after resolving all coroutines.
+
+        Returns:
+            The resolved delta for the state.
+        """
+        return await _resolve_delta(self.get_delta())
+
     def _mark_dirty(self):
         """Mark the substate and all parent states as dirty."""
         state_name = self.get_name()

+ 15 - 1
tests/integration/test_background_task.py

@@ -26,6 +26,15 @@ def BackgroundTask():
         def set_iterations(self, value: str):
             self.iterations = int(value)
 
+        @rx.var
+        async def counter_async_cv(self) -> int:
+            """This exists solely as an integration test for background tasks triggering async var updates.
+
+            Returns:
+                The current value of the counter.
+            """
+            return self.counter
+
         @rx.event(background=True)
         async def handle_event(self):
             async with self:
@@ -125,7 +134,10 @@ def BackgroundTask():
             rx.input(
                 id="token", value=State.router.session.client_token, is_read_only=True
             ),
-            rx.heading(State.counter, id="counter"),
+            rx.hstack(
+                rx.heading(State.counter, id="counter"),
+                rx.text(State.counter_async_cv, size="1", id="counter-async-cv"),
+            ),
             rx.input(
                 id="iterations",
                 placeholder="Iterations",
@@ -264,6 +276,7 @@ def test_background_task(
 
     # get a reference to the counter
     counter = driver.find_element(By.ID, "counter")
+    counter_async_cv = driver.find_element(By.ID, "counter-async-cv")
 
     # get a reference to the iterations input
     iterations_input = driver.find_element(By.ID, "iterations")
@@ -290,6 +303,7 @@ def test_background_task(
     yield_increment_button.click()
     blocking_pause_button.click()
     assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
+    assert background_task._poll_for(lambda: counter_async_cv.text == "620", timeout=40)
     # all tasks should have exited and cleaned up
     assert background_task._poll_for(
         lambda: not background_task.app_instance._background_tasks  # pyright: ignore [reportOptionalMemberAccess]