Przeglądaj źródła

Optimize StateManagerDisk (#4056)

* Simplify StateManagerDisk implementation

* Act more like the memory state manager and only track the root state in self.states
* .load_state always loads a single state or returns None
* .populate_states is the new entry point in loading from disk and it only occurs
  when the root state is not known
* much fast

* StateManagerDisk now acts much more like StateManagerMemory

Treat StateManagerDisk like StateManagerMemory for AppHarness

* Handle root_state deserialized from disk

In this case, we need to initialize the whole state tree, so any non-persistent
states will still get default values, whereas on-disk states will overwrite the
defaults.

* Cache root_state under client_token for StateManagerMemory compatibility

Mainly this just makes it easier for us to write tests that work against either
Disk or Memory state managers.
Masen Furer 7 miesięcy temu
rodzic
commit
aa69234b76
3 zmienionych plików z 36 dodań i 45 usunięć
  1. 32 30
      reflex/state.py
  2. 0 2
      reflex/testing.py
  3. 4 13
      tests/units/test_state.py

+ 32 - 30
reflex/state.py

@@ -2711,34 +2711,24 @@ class StateManagerDisk(StateManager):
             self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
         ).absolute()
 
-    async def load_state(self, token: str, root_state: BaseState) -> BaseState:
+    async def load_state(self, token: str) -> BaseState | None:
         """Load a state object based on the provided token.
 
         Args:
             token: The token used to identify the state object.
-            root_state: The root state object.
 
         Returns:
-            The loaded state object.
+            The loaded state object or None.
         """
-        if token in self.states:
-            return self.states[token]
-
-        client_token, substate_address = _split_substate_key(token)
-
         token_path = self.token_path(token)
 
         if token_path.exists():
             try:
                 with token_path.open(mode="rb") as file:
-                    substate = BaseState._deserialize(fp=file)
-                    await self.populate_substates(client_token, substate, root_state)
-                    return substate
+                    return BaseState._deserialize(fp=file)
             except Exception:
                 pass
 
-        return root_state.get_substate(substate_address.split(".")[1:])
-
     async def populate_substates(
         self, client_token: str, state: BaseState, root_state: BaseState
     ):
@@ -2752,10 +2742,13 @@ class StateManagerDisk(StateManager):
         for substate in state.get_substates():
             substate_token = _substate_key(client_token, substate)
 
-            substate = await self.load_state(substate_token, root_state)
+            instance = await self.load_state(substate_token)
+            if instance is None:
+                instance = await root_state.get_state(substate)
+            state.substates[substate.get_name()] = instance
+            instance.parent_state = state
 
-            state.substates[substate.get_name()] = substate
-            substate.parent_state = state
+            await self.populate_substates(client_token, instance, root_state)
 
     @override
     async def get_state(
@@ -2770,15 +2763,24 @@ class StateManagerDisk(StateManager):
         Returns:
             The state for the token.
         """
-        client_token, substate_address = _split_substate_key(token)
-
-        root_state_token = _substate_key(client_token, substate_address.split(".")[0])
-        root_state = self.states.get(root_state_token)
+        client_token = _split_substate_key(token)[0]
+        root_state = self.states.get(client_token)
+        if root_state is not None:
+            # Retrieved state from memory.
+            return root_state
+
+        # Deserialize root state from disk.
+        root_state = await self.load_state(_substate_key(client_token, self.state))
+        # Create a new root state tree with all substates instantiated.
+        fresh_root_state = self.state(_reflex_internal_init=True)
         if root_state is None:
-            # Create a new root state which will be persisted in the next set_state call.
-            root_state = self.state(_reflex_internal_init=True)
-
-        return await self.load_state(root_state_token, root_state)
+            root_state = fresh_root_state
+        else:
+            # Ensure all substates exist, even if they were not serialized previously.
+            root_state.substates = fresh_root_state.substates
+        self.states[client_token] = root_state
+        await self.populate_substates(client_token, root_state, root_state)
+        return root_state
 
     async def set_state_for_substate(self, client_token: str, substate: BaseState):
         """Set the state for a substate.
@@ -2789,12 +2791,12 @@ class StateManagerDisk(StateManager):
         """
         substate_token = _substate_key(client_token, substate)
 
-        self.states[substate_token] = substate
-
-        state_dilled = substate._serialize()
-        if not self.states_directory.exists():
-            self.states_directory.mkdir(parents=True, exist_ok=True)
-        self.token_path(substate_token).write_bytes(state_dilled)
+        if substate._get_was_touched():
+            substate._was_touched = False  # Reset the touched flag after serializing.
+            pickle_state = substate._serialize()
+            if not self.states_directory.exists():
+                self.states_directory.mkdir(parents=True, exist_ok=True)
+            self.token_path(substate_token).write_bytes(pickle_state)
 
         for substate_substate in substate.substates.values():
             await self.set_state_for_substate(client_token, substate_substate)

+ 0 - 2
reflex/testing.py

@@ -292,8 +292,6 @@ class AppHarness:
         if isinstance(self.app_instance._state_manager, StateManagerRedis):
             # Create our own redis connection for testing.
             self.state_manager = StateManagerRedis.create(self.app_instance.state)
-        elif isinstance(self.app_instance._state_manager, StateManagerDisk):
-            self.state_manager = StateManagerDisk.create(self.app_instance.state)
         else:
             self.state_manager = self.app_instance._state_manager
 

+ 4 - 13
tests/units/test_state.py

@@ -1884,11 +1884,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     async with sp:
         assert sp._self_actx is not None
         assert sp._self_mutable  # proxy is mutable inside context
-        if isinstance(mock_app.state_manager, StateManagerMemory):
+        if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
             # For in-process store, only one instance of the state exists
             assert sp.__wrapped__ is grandchild_state
         else:
-            # When redis or disk is used, a new+updated instance is assigned to the proxy
+            # When redis is used, a new+updated instance is assigned to the proxy
             assert sp.__wrapped__ is not grandchild_state
         sp.value2 = "42"
     assert not sp._self_mutable  # proxy is not mutable after exiting context
@@ -1899,7 +1899,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     gotten_state = await mock_app.state_manager.get_state(
         _substate_key(grandchild_state.router.session.client_token, grandchild_state)
     )
-    if isinstance(mock_app.state_manager, StateManagerMemory):
+    if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
         # For in-process store, only one instance of the state exists
         assert gotten_state is parent_state
     else:
@@ -2922,7 +2922,7 @@ async def test_get_state(mock_app: rx.App, token: str):
         _substate_key(token, ChildState2)
     )
     assert isinstance(new_test_state, TestState)
-    if isinstance(mock_app.state_manager, StateManagerMemory):
+    if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
         # In memory, it's the same instance
         assert new_test_state is test_state
         test_state._clean()
@@ -2932,15 +2932,6 @@ async def test_get_state(mock_app: rx.App, token: str):
             ChildState2.get_name(),
             ChildState3.get_name(),
         )
-    elif isinstance(mock_app.state_manager, StateManagerDisk):
-        # On disk, it's a new instance
-        assert new_test_state is not test_state
-        # All substates are available
-        assert tuple(sorted(new_test_state.substates)) == (
-            ChildState.get_name(),
-            ChildState2.get_name(),
-            ChildState3.get_name(),
-        )
     else:
         # With redis, we get a whole new instance
         assert new_test_state is not test_state