Explorar o código

[HOS-333] Send a "reload" message to the frontend after state expiry (#4442)

* Unit test updates

* test_client_storage: simulate backend state expiry

* [HOS-333] Send a "reload" message to the frontend after state expiry

1. a state instance expires on the backing store
2. frontend attempts to process an event against the expired token and gets a
   fresh instance of the state without router_data set
3. backend sends a "reload" message on the websocket containing the event and
   immediately stops processing
4. in response to the "reload" message, frontend sends
   [hydrate, update client storage, on_load, <previous_event>]

This allows the frontend and backend to re-syncronize on the state of the app
before continuing to process regular events.

If the event in (2) is a special hydrate event, then it is processed normally
by the middleware and the "reload" logic is skipped since this indicates an
initial load or a browser refresh.

* unit tests working with redis
Masen Furer hai 5 meses
pai
achega
39cdce6960

+ 4 - 0
reflex/.templates/web/utils/state.js

@@ -454,6 +454,10 @@ export const connect = async (
       queueEvents(update.events, socket);
     }
   });
+  socket.current.on("reload", async (event) => {
+    event_processing = false;
+    queueEvents([...initialEvents(), JSON5.parse(event)], socket);
+  })
 
   document.addEventListener("visibilitychange", checkVisibility);
 };

+ 16 - 0
reflex/app.py

@@ -73,6 +73,7 @@ from reflex.event import (
     EventSpec,
     EventType,
     IndividualEventType,
+    get_hydrate_event,
     window_alert,
 )
 from reflex.model import Model, get_db_status
@@ -1259,6 +1260,21 @@ async def process(
         )
         # Get the state for the session exclusively.
         async with app.state_manager.modify_state(event.substate_token) as state:
+            # When this is a brand new instance of the state, signal the
+            # frontend to reload before processing it.
+            if (
+                not state.router_data
+                and event.name != get_hydrate_event(state)
+                and app.event_namespace is not None
+            ):
+                await asyncio.create_task(
+                    app.event_namespace.emit(
+                        "reload",
+                        data=format.json_dumps(event),
+                        to=sid,
+                    )
+                )
+                return
             # re-assign only when the value is different
             if state.router_data != router_data:
                 # assignment will recurse into substates and force recalculation of

+ 3 - 0
reflex/state.py

@@ -1959,6 +1959,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 if var in self.base_vars or var in self._backend_vars:
                     self._was_touched = True
                     break
+                if var == constants.ROUTER_DATA and self.parent_state is None:
+                    self._was_touched = True
+                    break
 
     def _get_was_touched(self) -> bool:
         """Check current dirty_vars and flag to determine if state instance was modified.

+ 112 - 1
tests/integration/test_client_storage.py

@@ -10,6 +10,13 @@ from selenium.webdriver import Firefox
 from selenium.webdriver.common.by import By
 from selenium.webdriver.remote.webdriver import WebDriver
 
+from reflex.state import (
+    State,
+    StateManagerDisk,
+    StateManagerMemory,
+    StateManagerRedis,
+    _substate_key,
+)
 from reflex.testing import AppHarness
 
 from . import utils
@@ -74,7 +81,7 @@ def ClientSide():
         return rx.fragment(
             rx.input(
                 value=ClientSideState.router.session.client_token,
-                is_read_only=True,
+                read_only=True,
                 id="token",
             ),
             rx.input(
@@ -604,6 +611,110 @@ async def test_client_side_state(
     assert s2.text == "s2 value"
     assert s3.text == "s3 value"
 
+    # Simulate state expiration
+    if isinstance(client_side.state_manager, StateManagerRedis):
+        await client_side.state_manager.redis.delete(
+            _substate_key(token, State.get_full_name())
+        )
+        await client_side.state_manager.redis.delete(_substate_key(token, state_name))
+        await client_side.state_manager.redis.delete(
+            _substate_key(token, sub_state_name)
+        )
+        await client_side.state_manager.redis.delete(
+            _substate_key(token, sub_sub_state_name)
+        )
+    elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)):
+        del client_side.state_manager.states[token]
+    if isinstance(client_side.state_manager, StateManagerDisk):
+        client_side.state_manager.token_expiration = 0
+        client_side.state_manager._purge_expired_states()
+
+    # Ensure the state is gone (not hydrated)
+    async def poll_for_not_hydrated():
+        state = await client_side.get_state(_substate_key(token or "", state_name))
+        return not state.is_hydrated
+
+    assert await AppHarness._poll_for_async(poll_for_not_hydrated)
+
+    # Trigger event to get a new instance of the state since the old was expired.
+    state_var_input = driver.find_element(By.ID, "state_var")
+    state_var_input.send_keys("re-triggering")
+
+    # get new references to all cookie and local storage elements (again)
+    c1 = driver.find_element(By.ID, "c1")
+    c2 = driver.find_element(By.ID, "c2")
+    c3 = driver.find_element(By.ID, "c3")
+    c4 = driver.find_element(By.ID, "c4")
+    c5 = driver.find_element(By.ID, "c5")
+    c6 = driver.find_element(By.ID, "c6")
+    c7 = driver.find_element(By.ID, "c7")
+    l1 = driver.find_element(By.ID, "l1")
+    l2 = driver.find_element(By.ID, "l2")
+    l3 = driver.find_element(By.ID, "l3")
+    l4 = driver.find_element(By.ID, "l4")
+    s1 = driver.find_element(By.ID, "s1")
+    s2 = driver.find_element(By.ID, "s2")
+    s3 = driver.find_element(By.ID, "s3")
+    c1s = driver.find_element(By.ID, "c1s")
+    l1s = driver.find_element(By.ID, "l1s")
+    s1s = driver.find_element(By.ID, "s1s")
+
+    assert c1.text == "c1 value"
+    assert c2.text == "c2 value"
+    assert c3.text == ""  # temporary cookie expired after reset state!
+    assert c4.text == "c4 value"
+    assert c5.text == "c5 value"
+    assert c6.text == "c6 value"
+    assert c7.text == "c7 value"
+    assert l1.text == "l1 value"
+    assert l2.text == "l2 value"
+    assert l3.text == "l3 value"
+    assert l4.text == "l4 value"
+    assert s1.text == "s1 value"
+    assert s2.text == "s2 value"
+    assert s3.text == "s3 value"
+    assert c1s.text == "c1s value"
+    assert l1s.text == "l1s value"
+    assert s1s.text == "s1s value"
+
+    # Get the backend state and ensure the values are still set
+    async def get_sub_state():
+        root_state = await client_side.get_state(
+            _substate_key(token or "", sub_state_name)
+        )
+        state = root_state.substates[client_side.get_state_name("_client_side_state")]
+        sub_state = state.substates[
+            client_side.get_state_name("_client_side_sub_state")
+        ]
+        return sub_state
+
+    async def poll_for_c1_set():
+        sub_state = await get_sub_state()
+        return sub_state.c1 == "c1 value"
+
+    assert await AppHarness._poll_for_async(poll_for_c1_set)
+    sub_state = await get_sub_state()
+    assert sub_state.c1 == "c1 value"
+    assert sub_state.c2 == "c2 value"
+    assert sub_state.c3 == ""
+    assert sub_state.c4 == "c4 value"
+    assert sub_state.c5 == "c5 value"
+    assert sub_state.c6 == "c6 value"
+    assert sub_state.c7 == "c7 value"
+    assert sub_state.l1 == "l1 value"
+    assert sub_state.l2 == "l2 value"
+    assert sub_state.l3 == "l3 value"
+    assert sub_state.l4 == "l4 value"
+    assert sub_state.s1 == "s1 value"
+    assert sub_state.s2 == "s2 value"
+    assert sub_state.s3 == "s3 value"
+    sub_sub_state = sub_state.substates[
+        client_side.get_state_name("_client_side_sub_sub_state")
+    ]
+    assert sub_sub_state.c1s == "c1s value"
+    assert sub_sub_state.l1s == "l1s value"
+    assert sub_sub_state.s1s == "s1s value"
+
     # clear the cookie jar and local storage, ensure state reset to default
     driver.delete_all_cookies()
     local_storage.clear()

+ 6 - 2
tests/units/test_app.py

@@ -1007,8 +1007,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     substate_token = _substate_key(token, DynamicState)
     sid = "mock_sid"
     client_ip = "127.0.0.1"
-    state = await app.state_manager.get_state(substate_token)
-    assert state.dynamic == ""
+    async with app.state_manager.modify_state(substate_token) as state:
+        state.router_data = {"simulate": "hydrated"}
+        assert state.dynamic == ""
     exp_vals = ["foo", "foobar", "baz"]
 
     def _event(name, val, **kwargs):
@@ -1180,6 +1181,7 @@ async def test_process_events(mocker, token: str):
         "ip": "127.0.0.1",
     }
     app = App(state=GenState)
+
     mocker.patch.object(app, "_postprocess", AsyncMock())
     event = Event(
         token=token,
@@ -1187,6 +1189,8 @@ async def test_process_events(mocker, token: str):
         payload={"c": 5},
         router_data=router_data,
     )
+    async with app.state_manager.modify_state(event.substate_token) as state:
+        state.router_data = {"simulate": "hydrated"}
 
     async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
         pass

+ 9 - 3
tests/units/test_state.py

@@ -1982,6 +1982,10 @@ class BackgroundTaskState(BaseState):
     order: List[str] = []
     dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
 
+    def __init__(self, **kwargs):  # noqa: D107
+        super().__init__(**kwargs)
+        self.router_data = {"simulate": "hydrate"}
+
     @rx.var
     def computed_order(self) -> List[str]:
         """Get the order as a computed var.
@@ -2732,7 +2736,7 @@ def test_set_base_field_via_setter():
     assert "c2" in bfss.dirty_vars
 
 
-def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
+def exp_is_hydrated(state: BaseState, is_hydrated: bool = True) -> Dict[str, Any]:
     """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
 
     Args:
@@ -2811,7 +2815,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
     app = app_module_mock.app = App(
         state=State, load_events={"index": [test_state.test_handler]}
     )
-    state = State()
+    async with app.state_manager.modify_state(_substate_key(token, State)) as state:
+        state.router_data = {"simulate": "hydrate"}
 
     updates = []
     async for update in rx.app.process(
@@ -2858,7 +2863,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
         state=State,
         load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
     )
-    state = State()
+    async with app.state_manager.modify_state(_substate_key(token, State)) as state:
+        state.router_data = {"simulate": "hydrate"}
 
     updates = []
     async for update in rx.app.process(