瀏覽代碼

[REF-2219] Avoid refetching states that are already cached (#2953)

* Add test_get_state_from_sibling_not_cached

A better unit test to catch issues with refetching parent states
and calculating the wrong parent state names to fetch.

* _determine_missing_parent_states: correctly generate state names

Prepend only the previous state name to the current relative_parent_state_name
instead of joining all of the previous state names together.

* [REF-2219] Avoid refetching states that are already cached

The already cached states may have unsaved changes which can be wiped out if
they are refetched from redis in the middle of handling an event.

If the root state already knows about one of the potentially missing states,
then use the instance that is already cached.

Fix #2851
Masen Furer 1 年之前
父節點
當前提交
55b0fb36e8
共有 2 個文件被更改,包括 105 次插入3 次删除
  1. 12 3
      reflex/state.py
  2. 93 0
      tests/test_state.py

+ 12 - 3
reflex/state.py

@@ -1232,9 +1232,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
         fetch_parent_states = [common_ancestor_name]
-        for ix, relative_parent_state_name in enumerate(relative_target_state_parts):
+        for relative_parent_state_name in relative_target_state_parts:
             fetch_parent_states.append(
-                ".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name])
+                ".".join((fetch_parent_states[-1], relative_parent_state_name))
             )
 
         return common_ancestor_name, fetch_parent_states[1:-1]
@@ -1278,9 +1278,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         ) = self._determine_missing_parent_states(target_state_cls)
 
         # Fetch all missing parent states and link them up to the common ancestor.
-        parent_states_by_name = dict(self._get_parent_states())
+        parent_states_tuple = self._get_parent_states()
+        root_state = parent_states_tuple[-1][1]
+        parent_states_by_name = dict(parent_states_tuple)
         parent_state = parent_states_by_name[common_ancestor_name]
         for parent_state_name in missing_parent_states:
+            try:
+                parent_state = root_state.get_substate(parent_state_name.split("."))
+                # The requested state is already cached, do NOT fetch it again.
+                continue
+            except ValueError:
+                # The requested state is missing, fetch from redis.
+                pass
             parent_state = await state_manager.get_state(
                 token=_substate_key(
                     self.router.session.client_token, parent_state_name

+ 93 - 0
tests/test_state.py

@@ -2729,6 +2729,99 @@ async def test_get_state(mock_app: rx.App, token: str):
     }
 
 
+@pytest.mark.asyncio
+async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
+    """A test simulating update_vars_internal when setting cookies with computed vars.
+
+    In that case, a sibling state, UpdateVarsInternalState handles the fetching
+    of states that need to have values set. Only the states that have a computed
+    var are pre-fetched (like Child3 in this test), so `get_state` needs to
+    avoid refetching those already-cached states when getting substates,
+    otherwise the set values will be overridden by the freshly deserialized
+    version and lost.
+
+    Explicit regression test for https://github.com/reflex-dev/reflex/issues/2851.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+
+    class Parent(BaseState):
+        """A root state like rx.State."""
+
+        parent_var: int = 0
+
+    class Child(Parent):
+        """A state simulating UpdateVarsInternalState."""
+
+        pass
+
+    class Child2(Parent):
+        """An unconnected child state."""
+
+        pass
+
+    class Child3(Parent):
+        """A child state with a computed var causing it to be pre-fetched.
+
+        If child3_var gets set to a value, and `get_state` erroneously
+        re-fetches it from redis, the value will be lost.
+        """
+
+        child3_var: int = 0
+
+        @rx.var
+        def v(self):
+            pass
+
+    class Grandchild3(Child3):
+        """An extra layer of substate to catch an issue discovered in
+        _determine_missing_parent_states while writing the regression test where
+        invalid parent state names were being constructed.
+        """
+
+        pass
+
+    class GreatGrandchild3(Grandchild3):
+        """Fetching this state wants to also fetch Child3 as a missing parent.
+        However, Child3 should already be cached in the state tree because it
+        has a computed var.
+        """
+
+        pass
+
+    mock_app.state_manager.state = mock_app.state = Parent
+
+    # Get the top level state via unconnected sibling.
+    root = await mock_app.state_manager.get_state(_substate_key(token, Child))
+    # Set value in parent_var to assert it does not get refetched later.
+    root.parent_var = 1
+
+    if isinstance(mock_app.state_manager, StateManagerRedis):
+        # When redis is used, only states with computed vars are pre-fetched.
+        assert "child2" not in root.substates
+        assert "child3" in root.substates  # (due to @rx.var)
+
+    # Get the unconnected sibling state, which will be used to `get_state` other instances.
+    child = root.get_substate(Child.get_full_name().split("."))
+
+    # Get an uncached child state.
+    child2 = await child.get_state(Child2)
+    assert child2.parent_var == 1
+
+    # Set value on already-cached Child3 state (prefetched because it has a Computed Var).
+    child3 = await child.get_state(Child3)
+    child3.child3_var = 1
+
+    # Get uncached great_grandchild3 state.
+    great_grandchild3 = await child.get_state(GreatGrandchild3)
+
+    # Assert that we didn't re-fetch the parent and child3 state from redis
+    assert great_grandchild3.parent_var == 1
+    assert great_grandchild3.child3_var == 1
+
+
 # Save a reference to the rx.State to shadow the name State for testing.
 RxState = State