瀏覽代碼

minor State cleanup (#3768)

benedikt-bartscher 9 月之前
父節點
當前提交
2d9380a6fd
共有 2 個文件被更改,包括 23 次插入23 次删除
  1. 22 22
      reflex/state.py
  2. 1 1
      reflex/utils/types.py

+ 22 - 22
reflex/state.py

@@ -55,6 +55,7 @@ from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import SerializedType, serialize, serializer
+from reflex.utils.types import override
 from reflex.vars import BaseVar, ComputedVar, Var, computed_var
 
 if TYPE_CHECKING:
@@ -1232,6 +1233,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             parent_states_with_name.append((parent_state.get_full_name(), parent_state))
         return parent_states_with_name
 
+    def _get_root_state(self) -> BaseState:
+        """Get the root state of the state tree.
+
+        Returns:
+            The root state of the state tree.
+        """
+        parent_state = self
+        while parent_state.parent_state is not None:
+            parent_state = parent_state.parent_state
+        return parent_state
+
     async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
         """Populate substates in the tree between the target_state_cls and common ancestor of this state.
 
@@ -1291,10 +1303,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The instance of state_cls associated with this state's client_token.
         """
-        if self.parent_state is None:
-            root_state = self
-        else:
-            root_state = self._get_parent_states()[-1][1]
+        root_state = self._get_root_state()
         return root_state.get_substate(state_cls.get_full_name().split("."))
 
     async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
@@ -1445,9 +1454,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             The valid StateUpdate containing the events and final flag.
         """
         # get the delta from the root of the state tree
-        state = self
-        while state.parent_state is not None:
-            state = state.parent_state
+        state = self._get_root_state()
 
         token = self.router.session.client_token
 
@@ -2368,6 +2375,7 @@ class StateManagerMemory(StateManager):
             "_states_locks": {"exclude": True},
         }
 
+    @override
     async def get_state(self, token: str) -> BaseState:
         """Get the state for a token.
 
@@ -2383,6 +2391,7 @@ class StateManagerMemory(StateManager):
             self.states[token] = self.state(_reflex_internal_init=True)
         return self.states[token]
 
+    @override
     async def set_state(self, token: str, state: BaseState):
         """Set the state for a token.
 
@@ -2392,6 +2401,7 @@ class StateManagerMemory(StateManager):
         """
         pass
 
+    @override
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state for a token while holding exclusive lock.
@@ -2483,19 +2493,6 @@ class StateManagerRedis(StateManager):
     # Only warn about each state class size once.
     _warned_about_state_size: ClassVar[Set[str]] = set()
 
-    def _get_root_state(self, state: BaseState) -> BaseState:
-        """Chase parent_state pointers to find an instance of the top-level state.
-
-        Args:
-            state: The state to start from.
-
-        Returns:
-            An instance of the top-level state (self.state).
-        """
-        while type(state) != self.state and state.parent_state is not None:
-            state = state.parent_state
-        return state
-
     async def _get_parent_state(self, token: str) -> BaseState | None:
         """Get the parent state for the state requested in the token.
 
@@ -2558,6 +2555,7 @@ class StateManagerRedis(StateManager):
         for substate_name, substate_task in tasks.items():
             state.substates[substate_name] = await substate_task
 
+    @override
     async def get_state(
         self,
         token: str,
@@ -2609,7 +2607,7 @@ class StateManagerRedis(StateManager):
             # To retain compatibility with previous implementation, by default, we return
             # the top-level state by chasing `parent_state` pointers up the tree.
             if top_level:
-                return self._get_root_state(state)
+                return state._get_root_state()
             return state
 
         # TODO: dedupe the following logic with the above block
@@ -2631,7 +2629,7 @@ class StateManagerRedis(StateManager):
         # To retain compatibility with previous implementation, by default, we return
         # the top-level state by chasing `parent_state` pointers up the tree.
         if top_level:
-            return self._get_root_state(state)
+            return state._get_root_state()
         return state
 
     def _warn_if_too_large(
@@ -2657,6 +2655,7 @@ class StateManagerRedis(StateManager):
             )
             self._warned_about_state_size.add(state_full_name)
 
+    @override
     async def set_state(
         self,
         token: str,
@@ -2717,6 +2716,7 @@ class StateManagerRedis(StateManager):
         for t in tasks:
             await t
 
+    @override
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state for a token while holding exclusive lock.

+ 1 - 1
reflex/utils/types.py

@@ -49,7 +49,7 @@ from reflex.base import Base
 from reflex.utils import console
 
 if sys.version_info >= (3, 12):
-    from typing import override
+    from typing import override as override
 else:
 
     def override(func: Callable) -> Callable: