Bläddra i källkod

improve type support for .get_state (#4623)

* improve type support for .get_state

* dang it darglint
Khaleel Al-Adhami 4 månader sedan
förälder
incheckning
b50b7692b2
2 ändrade filer med 30 tillägg och 5 borttagningar
  1. 26 5
      reflex/state.py
  2. 4 0
      reflex/utils/exceptions.py

+ 26 - 5
reflex/state.py

@@ -104,6 +104,7 @@ from reflex.utils.exceptions import (
     LockExpiredError,
     ReflexRuntimeError,
     SetUndefinedStateVarError,
+    StateMismatchError,
     StateSchemaMismatchError,
     StateSerializationError,
     StateTooLargeError,
@@ -1543,7 +1544,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Return the direct parent of target_state_cls for subsequent linking.
         return parent_state
 
-    def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
+    def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get a state instance from the cache.
 
         Args:
@@ -1551,11 +1552,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         Returns:
             The instance of state_cls associated with this state's client_token.
+
+        Raises:
+            StateMismatchError: If the state instance is not of the expected type.
         """
         root_state = self._get_root_state()
-        return root_state.get_substate(state_cls.get_full_name().split("."))
+        substate = root_state.get_substate(state_cls.get_full_name().split("."))
+        if not isinstance(substate, state_cls):
+            raise StateMismatchError(
+                f"Searched for state {state_cls.get_full_name()} but found {substate}."
+            )
+        return substate
 
-    async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
+    async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get a state instance from redis.
 
         Args:
@@ -1566,6 +1575,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         Raises:
             RuntimeError: If redis is not used in this backend process.
+            StateMismatchError: If the state instance is not of the expected type.
         """
         # Fetch all missing parent states from redis.
         parent_state_of_state_cls = await self._populate_parent_states(state_cls)
@@ -1577,14 +1587,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
                 "(All states should already be available -- this is likely a bug).",
             )
-        return await state_manager.get_state(
+
+        state_in_redis = await state_manager.get_state(
             token=_substate_key(self.router.session.client_token, state_cls),
             top_level=False,
             get_substates=True,
             parent_state=parent_state_of_state_cls,
         )
 
-    async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
+        if not isinstance(state_in_redis, state_cls):
+            raise StateMismatchError(
+                f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
+            )
+
+        return state_in_redis
+
+    async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get an instance of the state associated with this token.
 
         Allows for arbitrary access to sibling states from within an event handler.
@@ -2316,6 +2334,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return state
 
 
+T_STATE = TypeVar("T_STATE", bound=BaseState)
+
+
 class State(BaseState):
     """The app Base State."""
 

+ 4 - 0
reflex/utils/exceptions.py

@@ -163,6 +163,10 @@ class StateSerializationError(ReflexError):
     """Raised when the state cannot be serialized."""
 
 
+class StateMismatchError(ReflexError, ValueError):
+    """Raised when the state retrieved does not match the expected state."""
+
+
 class SystemPackageMissingError(ReflexError):
     """Raised when a system package is missing."""