1
0
Эх сурвалжийг харах

Generate state delta from processed state instance (#2023)

Masen Furer 1 жил өмнө
parent
commit
1734ba0b6d
2 өөрчлөгдсөн 86 нэмэгдсэн , 17 устгасан
  1. 27 15
      reflex/state.py
  2. 59 2
      tests/test_state.py

+ 27 - 15
reflex/state.py

@@ -963,14 +963,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The valid StateUpdate containing the events and final flag.
         """
+        # get the delta from the root of the state tree
+        state = self
+        while state.parent_state is not None:
+            state = state.parent_state
+
         token = self.router.session.client_token
 
         # Convert valid EventHandler and EventSpec into Event
         fixed_events = fix_events(self._check_valid(handler, events), token)
 
         # Get the delta after processing the event.
-        delta = self.get_delta()
-        self._clean()
+        delta = state.get_delta()
+        state._clean()
 
         return StateUpdate(
             delta=delta,
@@ -1009,30 +1014,30 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             # Handle async generators.
             if inspect.isasyncgen(events):
                 async for event in events:
-                    yield self._as_state_update(handler, event, final=False)
-                yield self._as_state_update(handler, events=None, final=True)
+                    yield state._as_state_update(handler, event, final=False)
+                yield state._as_state_update(handler, events=None, final=True)
 
             # Handle regular generators.
             elif inspect.isgenerator(events):
                 try:
                     while True:
-                        yield self._as_state_update(handler, next(events), final=False)
+                        yield state._as_state_update(handler, next(events), final=False)
                 except StopIteration as si:
                     # the "return" value of the generator is not available
                     # in the loop, we must catch StopIteration to access it
                     if si.value is not None:
-                        yield self._as_state_update(handler, si.value, final=False)
-                yield self._as_state_update(handler, events=None, final=True)
+                        yield state._as_state_update(handler, si.value, final=False)
+                yield state._as_state_update(handler, events=None, final=True)
 
             # Handle regular event chains.
             else:
-                yield self._as_state_update(handler, events, final=True)
+                yield state._as_state_update(handler, events, final=True)
 
         # If an error occurs, throw a window alert.
         except Exception:
             error = traceback.format_exc()
             print(error)
-            yield self._as_state_update(
+            yield state._as_state_update(
                 handler,
                 window_alert("An error occurred. See logs for details."),
                 final=True,
@@ -1360,12 +1365,19 @@ class StateProxy(wrapt.ObjectProxy):
         Raises:
             ImmutableStateError: If the state is not in mutable mode.
         """
-        if not name.startswith("_self_") and not self._self_mutable:
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        super().__setattr__(name, value)
+        if (
+            name.startswith("_self_")  # wrapper attribute
+            or self._self_mutable  # lock held
+            # non-persisted state attribute
+            or name in self.__wrapped__.get_skip_vars()
+        ):
+            super().__setattr__(name, value)
+            return
+
+        raise ImmutableStateError(
+            "Background task StateProxy is immutable outside of a context "
+            "manager. Use `async with self` to modify state."
+        )
 
 
 class StateUpdate(Base):

+ 59 - 2
tests/test_state.py

@@ -1577,7 +1577,7 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
 
     setattr(app_module, CompileVars.APP, app)
     app.state = TestState
-    app.state_manager = state_manager
+    app._state_manager = state_manager
     app.event_namespace.emit = AsyncMock()  # type: ignore
     monkeypatch.setattr(prerequisites, "get_app", lambda: app_module)
     return app
@@ -1663,6 +1663,15 @@ class BackgroundTaskState(State):
     order: List[str] = []
     dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
 
+    @rx.var
+    def computed_order(self) -> List[str]:
+        """Get the order as a computed var.
+
+        Returns:
+            The value of 'order' var.
+        """
+        return self.order
+
     @rx.background
     async def background_task(self):
         """A background task that updates the state."""
@@ -1791,6 +1800,10 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
                         "background_task:start",
                         "other",
                     ],
+                    "computed_order": [
+                        "background_task:start",
+                        "other",
+                    ],
                 }
             }
         )
@@ -1800,7 +1813,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         await task
     assert not mock_app.background_tasks
 
-    assert (await mock_app.state_manager.get_state(token)).order == [
+    exp_order = [
         "background_task:start",
         "other",
         "background_task:stop",
@@ -1808,6 +1821,50 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         "private",
     ]
 
+    assert (await mock_app.state_manager.get_state(token)).order == exp_order
+
+    assert mock_app.event_namespace is not None
+    emit_mock = mock_app.event_namespace.emit
+
+    assert json.loads(emit_mock.mock_calls[0].args[1]) == {
+        "delta": {
+            "background_task_state": {
+                "order": ["background_task:start"],
+                "computed_order": ["background_task:start"],
+            }
+        },
+        "events": [],
+        "final": True,
+    }
+    for call in emit_mock.mock_calls[1:5]:
+        assert json.loads(call.args[1]) == {
+            "delta": {
+                "background_task_state": {"computed_order": ["background_task:start"]}
+            },
+            "events": [],
+            "final": True,
+        }
+    assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
+        "delta": {
+            "background_task_state": {
+                "order": exp_order,
+                "computed_order": exp_order,
+                "dict_list": {},
+            }
+        },
+        "events": [],
+        "final": True,
+    }
+    assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
+        "delta": {
+            "background_task_state": {
+                "computed_order": exp_order,
+            },
+        },
+        "events": [],
+        "final": True,
+    }
+
 
 @pytest.mark.asyncio
 async def test_background_task_reset(mock_app: rx.App, token: str):