Bläddra i källkod

Protect StateProxy with an asyncio.Lock (#3508)

* test_background_task: dispatch multiple async tasks

Use asyncio.gather to dispatch multiple tasks from a single background task
that all compete over the `async with self` lock. Even though the state itself
has a lock, each StateProxy instance should only allow a single `async with
self` context to run at a time.

* Protect StateProxy with an asyncio.Lock

Allow multiple tasks to reference the same StateProxy without stomping on each
other when entering an `async with self` context to acquire the state lock and
ultimately modify the state.
Masen Furer 11 månader sedan
förälder
incheckning
af3c9be97c
2 ändrade filer med 28 tillägg och 2 borttagningar
  1. 22 1
      integration/test_background_task.py
  2. 6 1
      reflex/state.py

+ 22 - 1
integration/test_background_task.py

@@ -57,6 +57,20 @@ def BackgroundTask():
         async def non_blocking_pause(self):
             await asyncio.sleep(0.02)
 
+        async def racy_task(self):
+            async with self:
+                self._task_id += 1
+            for _ix in range(int(self.iterations)):
+                async with self:
+                    self.counter += 1
+                await asyncio.sleep(0.005)
+
+        @rx.background
+        async def handle_racy_event(self):
+            await asyncio.gather(
+                self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
+            )
+
     def index() -> rx.Component:
         return rx.vstack(
             rx.chakra.input(
@@ -90,6 +104,11 @@ def BackgroundTask():
                 on_click=State.non_blocking_pause,
                 id="non-blocking-pause",
             ),
+            rx.button(
+                "Racy Increment (x4)",
+                on_click=State.handle_racy_event,
+                id="racy-increment",
+            ),
             rx.button("Reset", on_click=State.reset_counter, id="reset"),
         )
 
@@ -176,6 +195,7 @@ def test_background_task(
     increment_button = driver.find_element(By.ID, "increment")
     blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
     non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
+    racy_increment_button = driver.find_element(By.ID, "racy-increment")
     driver.find_element(By.ID, "reset")
 
     # get a reference to the counter
@@ -196,6 +216,7 @@ def test_background_task(
     delayed_increment_button.click()
     delayed_increment_button.click()
     yield_increment_button.click()
+    racy_increment_button.click()
     non_blocking_pause_button.click()
     yield_increment_button.click()
     blocking_pause_button.click()
@@ -204,7 +225,7 @@ def test_background_task(
         increment_button.click()
     yield_increment_button.click()
     blocking_pause_button.click()
-    assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
+    assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
     # all tasks should have exited and cleaned up
     assert background_task._poll_for(
         lambda: not background_task.app_instance.background_tasks  # type: ignore

+ 6 - 1
reflex/state.py

@@ -1988,6 +1988,7 @@ class StateProxy(wrapt.ObjectProxy):
         self._self_substate_path = state_instance.get_full_name().split(".")
         self._self_actx = None
         self._self_mutable = False
+        self._self_actx_lock = asyncio.Lock()
 
     async def __aenter__(self) -> StateProxy:
         """Enter the async context manager protocol.
@@ -2001,6 +2002,7 @@ class StateProxy(wrapt.ObjectProxy):
         Returns:
             This StateProxy instance in mutable mode.
         """
+        await self._self_actx_lock.acquire()
         self._self_actx = self._self_app.modify_state(
             token=_substate_key(
                 self.__wrapped__.router.session.client_token,
@@ -2025,7 +2027,10 @@ class StateProxy(wrapt.ObjectProxy):
         if self._self_actx is None:
             return
         self._self_mutable = False
-        await self._self_actx.__aexit__(*exc_info)
+        try:
+            await self._self_actx.__aexit__(*exc_info)
+        finally:
+            self._self_actx_lock.release()
         self._self_actx = None
 
     def __enter__(self):