فهرست منبع

improve type support for .get_state (#4623)

* improve type support for .get_state

* dang it darglint
Khaleel Al-Adhami 4 ماه پیش
والد
کامیت
b50b7692b2
2فایلهای تغییر یافته به همراه30 افزوده شده و 5 حذف شده
  1. 26 5
      reflex/state.py
  2. 4 0
      reflex/utils/exceptions.py

+ 26 - 5
reflex/state.py

@@ -104,6 +104,7 @@ from reflex.utils.exceptions import (
     LockExpiredError,
     LockExpiredError,
     ReflexRuntimeError,
     ReflexRuntimeError,
     SetUndefinedStateVarError,
     SetUndefinedStateVarError,
+    StateMismatchError,
     StateSchemaMismatchError,
     StateSchemaMismatchError,
     StateSerializationError,
     StateSerializationError,
     StateTooLargeError,
     StateTooLargeError,
@@ -1543,7 +1544,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Return the direct parent of target_state_cls for subsequent linking.
         # Return the direct parent of target_state_cls for subsequent linking.
         return parent_state
         return parent_state
 
 
-    def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
+    def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get a state instance from the cache.
         """Get a state instance from the cache.
 
 
         Args:
         Args:
@@ -1551,11 +1552,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         Returns:
         Returns:
             The instance of state_cls associated with this state's client_token.
             The instance of state_cls associated with this state's client_token.
+
+        Raises:
+            StateMismatchError: If the state instance is not of the expected type.
         """
         """
         root_state = self._get_root_state()
         root_state = self._get_root_state()
-        return root_state.get_substate(state_cls.get_full_name().split("."))
+        substate = root_state.get_substate(state_cls.get_full_name().split("."))
+        if not isinstance(substate, state_cls):
+            raise StateMismatchError(
+                f"Searched for state {state_cls.get_full_name()} but found {substate}."
+            )
+        return substate
 
 
-    async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
+    async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get a state instance from redis.
         """Get a state instance from redis.
 
 
         Args:
         Args:
@@ -1566,6 +1575,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         Raises:
         Raises:
             RuntimeError: If redis is not used in this backend process.
             RuntimeError: If redis is not used in this backend process.
+            StateMismatchError: If the state instance is not of the expected type.
         """
         """
         # Fetch all missing parent states from redis.
         # Fetch all missing parent states from redis.
         parent_state_of_state_cls = await self._populate_parent_states(state_cls)
         parent_state_of_state_cls = await self._populate_parent_states(state_cls)
@@ -1577,14 +1587,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
                 f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
                 "(All states should already be available -- this is likely a bug).",
                 "(All states should already be available -- this is likely a bug).",
             )
             )
-        return await state_manager.get_state(
+
+        state_in_redis = await state_manager.get_state(
             token=_substate_key(self.router.session.client_token, state_cls),
             token=_substate_key(self.router.session.client_token, state_cls),
             top_level=False,
             top_level=False,
             get_substates=True,
             get_substates=True,
             parent_state=parent_state_of_state_cls,
             parent_state=parent_state_of_state_cls,
         )
         )
 
 
-    async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
+        if not isinstance(state_in_redis, state_cls):
+            raise StateMismatchError(
+                f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
+            )
+
+        return state_in_redis
+
+    async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get an instance of the state associated with this token.
         """Get an instance of the state associated with this token.
 
 
         Allows for arbitrary access to sibling states from within an event handler.
         Allows for arbitrary access to sibling states from within an event handler.
@@ -2316,6 +2334,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return state
         return state
 
 
 
 
+T_STATE = TypeVar("T_STATE", bound=BaseState)
+
+
 class State(BaseState):
 class State(BaseState):
     """The app Base State."""
     """The app Base State."""
 
 

+ 4 - 0
reflex/utils/exceptions.py

@@ -163,6 +163,10 @@ class StateSerializationError(ReflexError):
     """Raised when the state cannot be serialized."""
     """Raised when the state cannot be serialized."""
 
 
 
 
+class StateMismatchError(ReflexError, ValueError):
+    """Raised when the state retrieved does not match the expected state."""
+
+
 class SystemPackageMissingError(ReflexError):
 class SystemPackageMissingError(ReflexError):
     """Raised when a system package is missing."""
     """Raised when a system package is missing."""