|
@@ -104,6 +104,7 @@ from reflex.utils.exceptions import (
|
|
LockExpiredError,
|
|
LockExpiredError,
|
|
ReflexRuntimeError,
|
|
ReflexRuntimeError,
|
|
SetUndefinedStateVarError,
|
|
SetUndefinedStateVarError,
|
|
|
|
+ StateMismatchError,
|
|
StateSchemaMismatchError,
|
|
StateSchemaMismatchError,
|
|
StateSerializationError,
|
|
StateSerializationError,
|
|
StateTooLargeError,
|
|
StateTooLargeError,
|
|
@@ -1543,7 +1544,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
# Return the direct parent of target_state_cls for subsequent linking.
|
|
# Return the direct parent of target_state_cls for subsequent linking.
|
|
return parent_state
|
|
return parent_state
|
|
|
|
|
|
- def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
|
|
|
|
|
|
+ def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
"""Get a state instance from the cache.
|
|
"""Get a state instance from the cache.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
@@ -1551,11 +1552,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
The instance of state_cls associated with this state's client_token.
|
|
The instance of state_cls associated with this state's client_token.
|
|
|
|
+
|
|
|
|
+ Raises:
|
|
|
|
+ StateMismatchError: If the state instance is not of the expected type.
|
|
"""
|
|
"""
|
|
root_state = self._get_root_state()
|
|
root_state = self._get_root_state()
|
|
- return root_state.get_substate(state_cls.get_full_name().split("."))
|
|
|
|
|
|
+ substate = root_state.get_substate(state_cls.get_full_name().split("."))
|
|
|
|
+ if not isinstance(substate, state_cls):
|
|
|
|
+ raise StateMismatchError(
|
|
|
|
+ f"Searched for state {state_cls.get_full_name()} but found {substate}."
|
|
|
|
+ )
|
|
|
|
+ return substate
|
|
|
|
|
|
- async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
|
|
|
|
|
|
+ async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
"""Get a state instance from redis.
|
|
"""Get a state instance from redis.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
@@ -1566,6 +1575,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
|
|
|
Raises:
|
|
Raises:
|
|
RuntimeError: If redis is not used in this backend process.
|
|
RuntimeError: If redis is not used in this backend process.
|
|
|
|
+ StateMismatchError: If the state instance is not of the expected type.
|
|
"""
|
|
"""
|
|
# Fetch all missing parent states from redis.
|
|
# Fetch all missing parent states from redis.
|
|
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
|
|
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
|
|
@@ -1577,14 +1587,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
|
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
|
"(All states should already be available -- this is likely a bug).",
|
|
"(All states should already be available -- this is likely a bug).",
|
|
)
|
|
)
|
|
- return await state_manager.get_state(
|
|
|
|
|
|
+
|
|
|
|
+ state_in_redis = await state_manager.get_state(
|
|
token=_substate_key(self.router.session.client_token, state_cls),
|
|
token=_substate_key(self.router.session.client_token, state_cls),
|
|
top_level=False,
|
|
top_level=False,
|
|
get_substates=True,
|
|
get_substates=True,
|
|
parent_state=parent_state_of_state_cls,
|
|
parent_state=parent_state_of_state_cls,
|
|
)
|
|
)
|
|
|
|
|
|
- async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
|
|
|
|
|
|
+ if not isinstance(state_in_redis, state_cls):
|
|
|
|
+ raise StateMismatchError(
|
|
|
|
+ f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ return state_in_redis
|
|
|
|
+
|
|
|
|
+ async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
"""Get an instance of the state associated with this token.
|
|
"""Get an instance of the state associated with this token.
|
|
|
|
|
|
Allows for arbitrary access to sibling states from within an event handler.
|
|
Allows for arbitrary access to sibling states from within an event handler.
|
|
@@ -2316,6 +2334,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
return state
|
|
return state
|
|
|
|
|
|
|
|
|
|
|
|
+T_STATE = TypeVar("T_STATE", bound=BaseState)
|
|
|
|
+
|
|
|
|
+
|
|
class State(BaseState):
|
|
class State(BaseState):
|
|
"""The app Base State."""
|
|
"""The app Base State."""
|
|
|
|
|