ソースを参照

[ENG-2287] Avoid fetching same state from redis multiple times (#4055)

* Avoid fetching substates multiple times

In the presence of computed vars, substates may be cached more than once.

* Consolidate logic in StateManagerRedis.get_state

* Suppress StateSchemaMismatchError and create a new state instance.

If the serialized state's schema does not match the current corresponding state
schema, then we have to create a new instance.
Masen Furer 7 ヶ月 前
コミット
7529bb0c64
1 ファイル変更33 行追加29 行削除
  1. 33 29
      reflex/state.py

+ 33 - 29
reflex/state.py

@@ -2916,11 +2916,14 @@ class StateManagerRedis(StateManager):
     # Only warn about each state class size once.
     # Only warn about each state class size once.
     _warned_about_state_size: ClassVar[Set[str]] = set()
     _warned_about_state_size: ClassVar[Set[str]] = set()
 
 
-    async def _get_parent_state(self, token: str) -> BaseState | None:
+    async def _get_parent_state(
+        self, token: str, state: BaseState | None = None
+    ) -> BaseState | None:
         """Get the parent state for the state requested in the token.
         """Get the parent state for the state requested in the token.
 
 
         Args:
         Args:
             token: The token to get the state for (_substate_key).
             token: The token to get the state for (_substate_key).
+            state: The state instance to get parent state for.
 
 
         Returns:
         Returns:
             The parent state for the state requested by the token or None if there is no such parent.
             The parent state for the state requested by the token or None if there is no such parent.
@@ -2929,11 +2932,15 @@ class StateManagerRedis(StateManager):
         client_token, state_path = _split_substate_key(token)
         client_token, state_path = _split_substate_key(token)
         parent_state_name = state_path.rpartition(".")[0]
         parent_state_name = state_path.rpartition(".")[0]
         if parent_state_name:
         if parent_state_name:
+            cached_substates = None
+            if state is not None:
+                cached_substates = [state]
             # Retrieve the parent state to populate event handlers onto this substate.
             # Retrieve the parent state to populate event handlers onto this substate.
             parent_state = await self.get_state(
             parent_state = await self.get_state(
                 token=_substate_key(client_token, parent_state_name),
                 token=_substate_key(client_token, parent_state_name),
                 top_level=False,
                 top_level=False,
                 get_substates=False,
                 get_substates=False,
+                cached_substates=cached_substates,
             )
             )
         return parent_state
         return parent_state
 
 
@@ -2965,6 +2972,8 @@ class StateManagerRedis(StateManager):
         tasks = {}
         tasks = {}
         # Retrieve the necessary substates from redis.
         # Retrieve the necessary substates from redis.
         for substate_cls in fetch_substates:
         for substate_cls in fetch_substates:
+            if substate_cls.get_name() in state.substates:
+                continue
             substate_name = substate_cls.get_name()
             substate_name = substate_cls.get_name()
             tasks[substate_name] = asyncio.create_task(
             tasks[substate_name] = asyncio.create_task(
                 self.get_state(
                 self.get_state(
@@ -2985,6 +2994,7 @@ class StateManagerRedis(StateManager):
         top_level: bool = True,
         top_level: bool = True,
         get_substates: bool = True,
         get_substates: bool = True,
         parent_state: BaseState | None = None,
         parent_state: BaseState | None = None,
+        cached_substates: list[BaseState] | None = None,
     ) -> BaseState:
     ) -> BaseState:
         """Get the state for a token.
         """Get the state for a token.
 
 
@@ -2993,6 +3003,7 @@ class StateManagerRedis(StateManager):
             top_level: If true, return an instance of the top-level state (self.state).
             top_level: If true, return an instance of the top-level state (self.state).
             get_substates: If true, also retrieve substates.
             get_substates: If true, also retrieve substates.
             parent_state: If provided, use this parent_state instead of getting it from redis.
             parent_state: If provided, use this parent_state instead of getting it from redis.
+            cached_substates: If provided, attach these substates to the state.
 
 
         Returns:
         Returns:
             The state for the token.
             The state for the token.
@@ -3010,45 +3021,38 @@ class StateManagerRedis(StateManager):
                 "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
                 "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
             )
             )
 
 
+        # The deserialized or newly created (sub)state instance.
+        state = None
+
         # Fetch the serialized substate from redis.
         # Fetch the serialized substate from redis.
         redis_state = await self.redis.get(token)
         redis_state = await self.redis.get(token)
 
 
         if redis_state is not None:
         if redis_state is not None:
             # Deserialize the substate.
             # Deserialize the substate.
-            state = BaseState._deserialize(data=redis_state)
-
-            # Populate parent state if missing and requested.
-            if parent_state is None:
-                parent_state = await self._get_parent_state(token)
-            # Set up Bidirectional linkage between this state and its parent.
-            if parent_state is not None:
-                parent_state.substates[state.get_name()] = state
-                state.parent_state = parent_state
-            # Populate substates if requested.
-            await self._populate_substates(token, state, all_substates=get_substates)
-
-            # To retain compatibility with previous implementation, by default, we return
-            # the top-level state by chasing `parent_state` pointers up the tree.
-            if top_level:
-                return state._get_root_state()
-            return state
-
-        # TODO: dedupe the following logic with the above block
-        # Key didn't exist so we have to create a new instance for this token.
+            with contextlib.suppress(StateSchemaMismatchError):
+                state = BaseState._deserialize(data=redis_state)
+        if state is None:
+            # Key didn't exist or schema mismatch so create a new instance for this token.
+            state = state_cls(
+                init_substates=False,
+                _reflex_internal_init=True,
+            )
+        # Populate parent state if missing and requested.
         if parent_state is None:
         if parent_state is None:
-            parent_state = await self._get_parent_state(token)
-        # Instantiate the new state class (but don't persist it yet).
-        state = state_cls(
-            parent_state=parent_state,
-            init_substates=False,
-            _reflex_internal_init=True,
-        )
+            parent_state = await self._get_parent_state(token, state)
         # Set up Bidirectional linkage between this state and its parent.
         # Set up Bidirectional linkage between this state and its parent.
         if parent_state is not None:
         if parent_state is not None:
             parent_state.substates[state.get_name()] = state
             parent_state.substates[state.get_name()] = state
             state.parent_state = parent_state
             state.parent_state = parent_state
-        # Populate substates for the newly created state.
+        # Avoid fetching substates multiple times.
+        if cached_substates:
+            for substate in cached_substates:
+                state.substates[substate.get_name()] = substate
+                if substate.parent_state is None:
+                    substate.parent_state = state
+        # Populate substates if requested.
         await self._populate_substates(token, state, all_substates=get_substates)
         await self._populate_substates(token, state, all_substates=get_substates)
+
         # To retain compatibility with previous implementation, by default, we return
         # To retain compatibility with previous implementation, by default, we return
         # the top-level state by chasing `parent_state` pointers up the tree.
         # the top-level state by chasing `parent_state` pointers up the tree.
         if top_level:
         if top_level: