Browse Source

[REF-3184] [REF-3339] Background task locking improvements (#3696)

* [REF-3184] Raise exception when encountering nested `async with self` blocks

Avoid deadlock when the background task already holds the mutation lock for a
given state.

* [REF-3339] get_state from background task links to StateProxy

When calling `get_state` from a background task, the resulting state instance
is wrapped in a StateProxy that is bound to the original StateProxy and shares
the same async context, lock, and mutability flag.

* If StateProxy has a _self_parent_state_proxy, retrieve the correct substate

* test_state fixup
Masen Furer 10 months ago
parent
commit
0845d2ee76
3 changed files with 161 additions and 10 deletions
  1. 103 0
      integration/test_background_task.py
  2. 57 9
      reflex/state.py
  3. 1 1
      tests/test_state.py

+ 103 - 0
integration/test_background_task.py

@@ -12,7 +12,10 @@ def BackgroundTask():
     """Test that background tasks work as expected."""
     import asyncio
 
+    import pytest
+
     import reflex as rx
+    from reflex.state import ImmutableStateError
 
     class State(rx.State):
         counter: int = 0
@@ -71,6 +74,38 @@ def BackgroundTask():
                 self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
             )
 
+        @rx.background
+        async def nested_async_with_self(self):
+            async with self:
+                self.counter += 1
+                with pytest.raises(ImmutableStateError):
+                    async with self:
+                        self.counter += 1
+
+        async def triple_count(self):
+            third_state = await self.get_state(ThirdState)
+            await third_state._triple_count()
+
+    class OtherState(rx.State):
+        @rx.background
+        async def get_other_state(self):
+            async with self:
+                state = await self.get_state(State)
+                state.counter += 1
+                await state.triple_count()
+            with pytest.raises(ImmutableStateError):
+                await state.triple_count()
+            with pytest.raises(ImmutableStateError):
+                state.counter += 1
+            async with state:
+                state.counter += 1
+                await state.triple_count()
+
+    class ThirdState(rx.State):
+        async def _triple_count(self):
+            state = await self.get_state(State)
+            state.counter *= 3
+
     def index() -> rx.Component:
         return rx.vstack(
             rx.chakra.input(
@@ -109,6 +144,16 @@ def BackgroundTask():
                 on_click=State.handle_racy_event,
                 id="racy-increment",
             ),
+            rx.button(
+                "Nested Async with Self",
+                on_click=State.nested_async_with_self,
+                id="nested-async-with-self",
+            ),
+            rx.button(
+                "Increment from OtherState",
+                on_click=OtherState.get_other_state,
+                id="increment-from-other-state",
+            ),
             rx.button("Reset", on_click=State.reset_counter, id="reset"),
         )
 
@@ -230,3 +275,61 @@ def test_background_task(
     assert background_task._poll_for(
         lambda: not background_task.app_instance.background_tasks  # type: ignore
     )
+
+
+def test_nested_async_with_self(
+    background_task: AppHarness,
+    driver: WebDriver,
+    token: str,
+):
+    """Test that nested async with self in the same coroutine raises Exception.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert background_task.app_instance is not None
+
+    # get a reference to all buttons
+    nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
+    increment_button = driver.find_element(By.ID, "increment")
+
+    # get a reference to the counter
+    counter = driver.find_element(By.ID, "counter")
+    assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
+
+    nested_async_with_self_button.click()
+    assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
+
+    increment_button.click()
+    assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
+
+
+def test_get_state(
+    background_task: AppHarness,
+    driver: WebDriver,
+    token: str,
+):
+    """Test that get_state returns a state bound to the correct StateProxy.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert background_task.app_instance is not None
+
+    # get a reference to all buttons
+    other_state_button = driver.find_element(By.ID, "increment-from-other-state")
+    increment_button = driver.find_element(By.ID, "increment")
+
+    # get a reference to the counter
+    counter = driver.find_element(By.ID, "counter")
+    assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
+
+    other_state_button.click()
+    assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
+
+    increment_button.click()
+    assert background_task._poll_for(lambda: counter.text == "13", timeout=5)

+ 57 - 9
reflex/state.py

@@ -202,7 +202,7 @@ def _no_chain_background_task(
 
 def _substate_key(
     token: str,
-    state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
+    state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str],
 ) -> str:
     """Get the substate key.
 
@@ -2029,19 +2029,38 @@ class StateProxy(wrapt.ObjectProxy):
                     self.counter += 1
     """
 
-    def __init__(self, state_instance):
+    def __init__(
+        self, state_instance, parent_state_proxy: Optional["StateProxy"] = None
+    ):
         """Create a proxy for a state instance.
 
+        If `get_state` is used on a StateProxy, the resulting state will be
+        linked to the given state via parent_state_proxy. The first state in the
+        chain is the state that initiated the background task.
+
         Args:
             state_instance: The state instance to proxy.
+            parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
         """
         super().__init__(state_instance)
         # compile is not relevant to backend logic
         self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
-        self._self_substate_path = state_instance.get_full_name().split(".")
+        self._self_substate_path = tuple(state_instance.get_full_name().split("."))
         self._self_actx = None
         self._self_mutable = False
         self._self_actx_lock = asyncio.Lock()
+        self._self_actx_lock_holder = None
+        self._self_parent_state_proxy = parent_state_proxy
+
+    def _is_mutable(self) -> bool:
+        """Check if the state is mutable.
+
+        Returns:
+            Whether the state is mutable.
+        """
+        if self._self_parent_state_proxy is not None:
+            return self._self_parent_state_proxy._is_mutable()
+        return self._self_mutable
 
     async def __aenter__(self) -> StateProxy:
         """Enter the async context manager protocol.
@@ -2054,8 +2073,31 @@ class StateProxy(wrapt.ObjectProxy):
 
         Returns:
             This StateProxy instance in mutable mode.
-        """
+
+        Raises:
+            ImmutableStateError: If the state is already mutable.
+        """
+        if self._self_parent_state_proxy is not None:
+            parent_state = (
+                await self._self_parent_state_proxy.__aenter__()
+            ).__wrapped__
+            super().__setattr__(
+                "__wrapped__",
+                await parent_state.get_state(
+                    State.get_class_substate(self._self_substate_path)
+                ),
+            )
+            return self
+        current_task = asyncio.current_task()
+        if (
+            self._self_actx_lock.locked()
+            and current_task == self._self_actx_lock_holder
+        ):
+            raise ImmutableStateError(
+                "The state is already mutable. Do not nest `async with self` blocks."
+            )
         await self._self_actx_lock.acquire()
+        self._self_actx_lock_holder = current_task
         self._self_actx = self._self_app.modify_state(
             token=_substate_key(
                 self.__wrapped__.router.session.client_token,
@@ -2077,12 +2119,16 @@ class StateProxy(wrapt.ObjectProxy):
         Args:
             exc_info: The exception info tuple.
         """
+        if self._self_parent_state_proxy is not None:
+            await self._self_parent_state_proxy.__aexit__(*exc_info)
+            return
         if self._self_actx is None:
             return
         self._self_mutable = False
         try:
             await self._self_actx.__aexit__(*exc_info)
         finally:
+            self._self_actx_lock_holder = None
             self._self_actx_lock.release()
         self._self_actx = None
 
@@ -2117,7 +2163,7 @@ class StateProxy(wrapt.ObjectProxy):
         Raises:
             ImmutableStateError: If the state is not in mutable mode.
         """
-        if name in ["substates", "parent_state"] and not self._self_mutable:
+        if name in ["substates", "parent_state"] and not self._is_mutable():
             raise ImmutableStateError(
                 "Background task StateProxy is immutable outside of a context "
                 "manager. Use `async with self` to modify state."
@@ -2157,7 +2203,7 @@ class StateProxy(wrapt.ObjectProxy):
         """
         if (
             name.startswith("_self_")  # wrapper attribute
-            or self._self_mutable  # lock held
+            or self._is_mutable()  # lock held
             # non-persisted state attribute
             or name in self.__wrapped__.get_skip_vars()
         ):
@@ -2181,7 +2227,7 @@ class StateProxy(wrapt.ObjectProxy):
         Raises:
             ImmutableStateError: If the state is not in mutable mode.
         """
-        if not self._self_mutable:
+        if not self._is_mutable():
             raise ImmutableStateError(
                 "Background task StateProxy is immutable outside of a context "
                 "manager. Use `async with self` to modify state."
@@ -2200,12 +2246,14 @@ class StateProxy(wrapt.ObjectProxy):
         Raises:
             ImmutableStateError: If the state is not in mutable mode.
         """
-        if not self._self_mutable:
+        if not self._is_mutable():
             raise ImmutableStateError(
                 "Background task StateProxy is immutable outside of a context "
                 "manager. Use `async with self` to modify state."
             )
-        return await self.__wrapped__.get_state(state_cls)
+        return type(self)(
+            await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
+        )
 
     def _as_state_update(self, *args, **kwargs) -> StateUpdate:
         """Temporarily allow mutability to access parent_state.

+ 1 - 1
tests/test_state.py

@@ -1825,7 +1825,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
 
     sp = StateProxy(grandchild_state)
     assert sp.__wrapped__ == grandchild_state
-    assert sp._self_substate_path == grandchild_state.get_full_name().split(".")
+    assert sp._self_substate_path == tuple(grandchild_state.get_full_name().split("."))
     assert sp._self_app is mock_app
     assert not sp._self_mutable
     assert sp._self_actx is None