Преглед на файлове

state.js: set event_processing = false when websocket connects (#3443)

* test_connection_banner: Improve assertions

* Actually assert on the presence or absense of the connection banner
* Update XPATH selector to find the connection toast
* Add event handling while backend down to verify queue functionality
* Bring backend down while an event is running to ensure queue does not get
  blocked (#3404)

* state.js: set event_processing = false when websocket connects

In case an event was pending when the websocket went down, allow further events
to be processed when it comes back up.

Fix #3404

* test_connection_banner: wait for token indicating backend is connected

* test_connection_banner: increase delay time

make the time window longer in which the backend can go down and get stuck in
event_processing=true for better test reliability

* Ensure the redis connection is reset in new backend thread

Redis has an event loop affinity and needs to be attached to the event loop
that the thread is running.

* Reset event_processing on disconnect

* if the socket never comes back up, it still allows client-side events to be
  processed
* on_mount events may start running before the socket is up, so resetting the
  flag on connect may break event determinism (test_event_chain.py)
Masen Furer преди 11 месеца
родител
ревизия
e138d9dfd0
променени са 3 файла, в които са добавени 76 реда и са изтрити 6 реда
  1. 49 6
      integration/test_connection_banner.py
  2. 6 0
      reflex/.templates/web/utils/state.js
  3. 21 0
      reflex/testing.py

+ 49 - 6
integration/test_connection_banner.py

@@ -8,16 +8,32 @@ from selenium.webdriver.common.by import By
 
 from reflex.testing import AppHarness, WebDriver
 
+from .utils import SessionStorage
+
 
 def ConnectionBanner():
     """App with a connection banner."""
+    import asyncio
+
     import reflex as rx
 
     class State(rx.State):
         foo: int = 0
 
+        async def delay(self):
+            await asyncio.sleep(5)
+
     def index():
-        return rx.text("Hello World")
+        return rx.vstack(
+            rx.text("Hello World"),
+            rx.input(value=State.foo, read_only=True, id="counter"),
+            rx.button(
+                "Increment",
+                id="increment",
+                on_click=State.set_foo(State.foo + 1),  # type: ignore
+            ),
+            rx.button("Delay", id="delay", on_click=State.delay),
+        )
 
     app = rx.App(state=rx.State)
     app.add_page(index)
@@ -40,7 +56,7 @@ def connection_banner(tmp_path) -> Generator[AppHarness, None, None]:
         yield harness
 
 
-CONNECTION_ERROR_XPATH = "//*[ text() = 'Connection Error' ]"
+CONNECTION_ERROR_XPATH = "//*[ contains(text(), 'Cannot connect to server') ]"
 
 
 def has_error_modal(driver: WebDriver) -> bool:
@@ -59,7 +75,8 @@ def has_error_modal(driver: WebDriver) -> bool:
         return False
 
 
-def test_connection_banner(connection_banner: AppHarness):
+@pytest.mark.asyncio
+async def test_connection_banner(connection_banner: AppHarness):
     """Test that the connection banner is displayed when the websocket drops.
 
     Args:
@@ -69,7 +86,23 @@ def test_connection_banner(connection_banner: AppHarness):
     assert connection_banner.backend is not None
     driver = connection_banner.frontend()
 
-    connection_banner._poll_for(lambda: not has_error_modal(driver))
+    ss = SessionStorage(driver)
+    assert connection_banner._poll_for(
+        lambda: ss.get("token") is not None
+    ), "token not found"
+
+    assert connection_banner._poll_for(lambda: not has_error_modal(driver))
+
+    delay_button = driver.find_element(By.ID, "delay")
+    increment_button = driver.find_element(By.ID, "increment")
+    counter_element = driver.find_element(By.ID, "counter")
+
+    # Increment the counter
+    increment_button.click()
+    assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1"
+
+    # Start an long event before killing the backend, to mark event_processing=true
+    delay_button.click()
 
     # Get the backend port
     backend_port = connection_banner._poll_for_servers().getsockname()[1]
@@ -80,10 +113,20 @@ def test_connection_banner(connection_banner: AppHarness):
         connection_banner.backend_thread.join()
 
     # Error modal should now be displayed
-    connection_banner._poll_for(lambda: has_error_modal(driver))
+    assert connection_banner._poll_for(lambda: has_error_modal(driver))
+
+    # Increment the counter with backend down
+    increment_button.click()
+    assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1"
 
     # Bring the backend back up
     connection_banner._start_backend(port=backend_port)
 
+    # Create a new StateManager to avoid async loop affinity issues w/ redis
+    await connection_banner._reset_backend_state_manager()
+
     # Banner should be gone now
-    connection_banner._poll_for(lambda: not has_error_modal(driver))
+    assert connection_banner._poll_for(lambda: not has_error_modal(driver))
+
+    # Count should have incremented after coming back up
+    assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"

+ 6 - 0
reflex/.templates/web/utils/state.js

@@ -358,6 +358,12 @@ export const connect = async (
   socket.current.on("connect_error", (error) => {
     setConnectErrors((connectErrors) => [connectErrors.slice(-9), error]);
   });
+
+  // When the socket disconnects reset the event_processing flag
+  socket.current.on("disconnect", () => {
+    event_processing = false;
+  });
+
   // On each received message, queue the updates and events.
   socket.current.on("event", (message) => {
     const update = JSON5.parse(message);

+ 21 - 0
reflex/testing.py

@@ -297,6 +297,27 @@ class AppHarness:
         self.backend_thread = threading.Thread(target=self.backend.run)
         self.backend_thread.start()
 
+    async def _reset_backend_state_manager(self):
+        """Reset the StateManagerRedis event loop affinity.
+
+        This is necessary when the backend is restarted and the state manager is a
+        StateManagerRedis instance.
+        """
+        if (
+            self.app_instance is not None
+            and isinstance(
+                self.app_instance.state_manager,
+                StateManagerRedis,
+            )
+            and self.app_instance.state is not None
+        ):
+            with contextlib.suppress(RuntimeError):
+                await self.app_instance.state_manager.close()
+            self.app_instance._state_manager = StateManagerRedis.create(
+                state=self.app_instance.state,
+            )
+            assert isinstance(self.app_instance.state_manager, StateManagerRedis)
+
     def _start_frontend(self):
         # Set up the frontend.
         with chdir(self.app_path):