浏览代码

[ENG-2287] Avoid fetching same state from redis multiple times (#4055)

* Avoid fetching substates multiple times

In the presence of computed vars, substates may be cached more than once.

* Consolidate logic in StateManagerRedis.get_state

* Suppress StateSchemaMismatchError and create a new state instance.

If the serialized state's schema does not match the current corresponding state
schema, then we have to create a new instance.
Masen Furer 7 月之前
父节点
当前提交
7529bb0c64
共有 1 个文件被更改,包括 33 次插入29 次删除
  1. 33 29
      reflex/state.py

+ 33 - 29
reflex/state.py

@@ -2916,11 +2916,14 @@ class StateManagerRedis(StateManager):
     # Only warn about each state class size once.
     _warned_about_state_size: ClassVar[Set[str]] = set()
 
-    async def _get_parent_state(self, token: str) -> BaseState | None:
+    async def _get_parent_state(
+        self, token: str, state: BaseState | None = None
+    ) -> BaseState | None:
         """Get the parent state for the state requested in the token.
 
         Args:
             token: The token to get the state for (_substate_key).
+            state: The state instance to get parent state for.
 
         Returns:
             The parent state for the state requested by the token or None if there is no such parent.
@@ -2929,11 +2932,15 @@ class StateManagerRedis(StateManager):
         client_token, state_path = _split_substate_key(token)
         parent_state_name = state_path.rpartition(".")[0]
         if parent_state_name:
+            cached_substates = None
+            if state is not None:
+                cached_substates = [state]
             # Retrieve the parent state to populate event handlers onto this substate.
             parent_state = await self.get_state(
                 token=_substate_key(client_token, parent_state_name),
                 top_level=False,
                 get_substates=False,
+                cached_substates=cached_substates,
             )
         return parent_state
 
@@ -2965,6 +2972,8 @@ class StateManagerRedis(StateManager):
         tasks = {}
         # Retrieve the necessary substates from redis.
         for substate_cls in fetch_substates:
+            if substate_cls.get_name() in state.substates:
+                continue
             substate_name = substate_cls.get_name()
             tasks[substate_name] = asyncio.create_task(
                 self.get_state(
@@ -2985,6 +2994,7 @@ class StateManagerRedis(StateManager):
         top_level: bool = True,
         get_substates: bool = True,
         parent_state: BaseState | None = None,
+        cached_substates: list[BaseState] | None = None,
     ) -> BaseState:
         """Get the state for a token.
 
@@ -2993,6 +3003,7 @@ class StateManagerRedis(StateManager):
             top_level: If true, return an instance of the top-level state (self.state).
             get_substates: If true, also retrieve substates.
             parent_state: If provided, use this parent_state instead of getting it from redis.
+            cached_substates: If provided, attach these substates to the state.
 
         Returns:
             The state for the token.
@@ -3010,45 +3021,38 @@ class StateManagerRedis(StateManager):
                 "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
             )
 
+        # The deserialized or newly created (sub)state instance.
+        state = None
+
         # Fetch the serialized substate from redis.
         redis_state = await self.redis.get(token)
 
         if redis_state is not None:
             # Deserialize the substate.
-            state = BaseState._deserialize(data=redis_state)
-
-            # Populate parent state if missing and requested.
-            if parent_state is None:
-                parent_state = await self._get_parent_state(token)
-            # Set up Bidirectional linkage between this state and its parent.
-            if parent_state is not None:
-                parent_state.substates[state.get_name()] = state
-                state.parent_state = parent_state
-            # Populate substates if requested.
-            await self._populate_substates(token, state, all_substates=get_substates)
-
-            # To retain compatibility with previous implementation, by default, we return
-            # the top-level state by chasing `parent_state` pointers up the tree.
-            if top_level:
-                return state._get_root_state()
-            return state
-
-        # TODO: dedupe the following logic with the above block
-        # Key didn't exist so we have to create a new instance for this token.
+            with contextlib.suppress(StateSchemaMismatchError):
+                state = BaseState._deserialize(data=redis_state)
+        if state is None:
+            # Key didn't exist or schema mismatch so create a new instance for this token.
+            state = state_cls(
+                init_substates=False,
+                _reflex_internal_init=True,
+            )
+        # Populate parent state if missing and requested.
         if parent_state is None:
-            parent_state = await self._get_parent_state(token)
-        # Instantiate the new state class (but don't persist it yet).
-        state = state_cls(
-            parent_state=parent_state,
-            init_substates=False,
-            _reflex_internal_init=True,
-        )
+            parent_state = await self._get_parent_state(token, state)
         # Set up Bidirectional linkage between this state and its parent.
         if parent_state is not None:
             parent_state.substates[state.get_name()] = state
             state.parent_state = parent_state
-        # Populate substates for the newly created state.
+        # Avoid fetching substates multiple times.
+        if cached_substates:
+            for substate in cached_substates:
+                state.substates[substate.get_name()] = substate
+                if substate.parent_state is None:
+                    substate.parent_state = state
+        # Populate substates if requested.
         await self._populate_substates(token, state, all_substates=get_substates)
+
         # To retain compatibility with previous implementation, by default, we return
         # the top-level state by chasing `parent_state` pointers up the tree.
         if top_level: