|
@@ -213,21 +213,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# The router data for the current page
|
|
|
router: RouterData = RouterData()
|
|
|
|
|
|
- def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ *args,
|
|
|
+ parent_state: BaseState | None = None,
|
|
|
+ init_substates: bool = True,
|
|
|
+ **kwargs,
|
|
|
+ ):
|
|
|
"""Initialize the state.
|
|
|
|
|
|
Args:
|
|
|
*args: The args to pass to the Pydantic init method.
|
|
|
parent_state: The parent state.
|
|
|
+ init_substates: Whether to initialize the substates in this instance.
|
|
|
**kwargs: The kwargs to pass to the Pydantic init method.
|
|
|
|
|
|
"""
|
|
|
kwargs["parent_state"] = parent_state
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
- # Setup the substates.
|
|
|
- for substate in self.get_substates():
|
|
|
- self.substates[substate.get_name()] = substate(parent_state=self)
|
|
|
+ # Setup the substates (for memory state manager only).
|
|
|
+ if init_substates:
|
|
|
+ for substate in self.get_substates():
|
|
|
+ self.substates[substate.get_name()] = substate(parent_state=self)
|
|
|
# Convert the event handlers to functions.
|
|
|
self._init_event_handlers()
|
|
|
|
|
@@ -1005,7 +1013,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
for substate in self.substates.values():
|
|
|
substate._reset_client_storage()
|
|
|
|
|
|
- def get_substate(self, path: Sequence[str]) -> BaseState | None:
|
|
|
+ def get_substate(self, path: Sequence[str]) -> BaseState:
|
|
|
"""Get the substate.
|
|
|
|
|
|
Args:
|
|
@@ -1260,6 +1268,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# Recursively find the substate deltas.
|
|
|
substates = self.substates
|
|
|
for substate in self.dirty_substates.union(self._always_dirty_substates):
|
|
|
+ if substate not in substates:
|
|
|
+ continue # substate not loaded at this time, no delta
|
|
|
delta.update(substates[substate].get_delta())
|
|
|
|
|
|
# Format the delta.
|
|
@@ -1287,6 +1297,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
for var in self.dirty_vars:
|
|
|
for substate_name in self._substate_var_dependencies[var]:
|
|
|
self.dirty_substates.add(substate_name)
|
|
|
+ if substate_name not in substates:
|
|
|
+ continue
|
|
|
substate = substates[substate_name]
|
|
|
substate.dirty_vars.add(var)
|
|
|
substate._mark_dirty()
|
|
@@ -1295,6 +1307,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
"""Reset the dirty vars."""
|
|
|
# Recursively clean the substates.
|
|
|
for substate in self.dirty_substates:
|
|
|
+ if substate not in self.substates:
|
|
|
+ continue
|
|
|
self.substates[substate]._clean()
|
|
|
|
|
|
# Clean this state.
|
|
@@ -1380,6 +1394,24 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
"""
|
|
|
pass
|
|
|
|
|
|
+ def __getstate__(self):
|
|
|
+ """Get the state for redis serialization.
|
|
|
+
|
|
|
+ This method is called by cloudpickle to serialize the object.
|
|
|
+
|
|
|
+ It explicitly removes parent_state and substates because those are serialized separately
|
|
|
+ by the StateManagerRedis to allow for better horizontal scaling as state size increases.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state dict for serialization.
|
|
|
+ """
|
|
|
+ state = super().__getstate__()
|
|
|
+ # Never serialize parent_state or substates
|
|
|
+ state["__dict__"] = state["__dict__"].copy()
|
|
|
+ state["__dict__"]["parent_state"] = None
|
|
|
+ state["__dict__"]["substates"] = {}
|
|
|
+ return state
|
|
|
+
|
|
|
|
|
|
class State(BaseState):
|
|
|
"""The app Base State."""
|
|
@@ -1479,6 +1511,8 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
"""
|
|
|
self._self_actx = self._self_app.modify_state(
|
|
|
self.__wrapped__.router.session.client_token
|
|
|
+ + "_"
|
|
|
+ + ".".join(self._self_substate_path)
|
|
|
)
|
|
|
mutable_state = await self._self_actx.__aenter__()
|
|
|
super().__setattr__(
|
|
@@ -1675,6 +1709,8 @@ class StateManagerMemory(StateManager):
|
|
|
Returns:
|
|
|
The state for the token.
|
|
|
"""
|
|
|
+ # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
+ token = token.partition("_")[0]
|
|
|
if token not in self.states:
|
|
|
self.states[token] = self.state()
|
|
|
return self.states[token]
|
|
@@ -1698,6 +1734,8 @@ class StateManagerMemory(StateManager):
|
|
|
Yields:
|
|
|
The state for the token.
|
|
|
"""
|
|
|
+ # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
+ token = token.partition("_")[0]
|
|
|
if token not in self._states_locks:
|
|
|
async with self._state_manager_lock:
|
|
|
if token not in self._states_locks:
|
|
@@ -1737,23 +1775,104 @@ class StateManagerRedis(StateManager):
|
|
|
b"evicted",
|
|
|
}
|
|
|
|
|
|
- async def get_state(self, token: str) -> BaseState:
|
|
|
+ async def get_state(
|
|
|
+ self,
|
|
|
+ token: str,
|
|
|
+ top_level: bool = True,
|
|
|
+ get_substates: bool = True,
|
|
|
+ parent_state: BaseState | None = None,
|
|
|
+ ) -> BaseState:
|
|
|
"""Get the state for a token.
|
|
|
|
|
|
Args:
|
|
|
token: The token to get the state for.
|
|
|
+ top_level: If true, return an instance of the top-level state.
|
|
|
+ get_substates: If true, also retrieve substates
|
|
|
+ parent_state: If provided, use this parent_state instead of getting it from redis.
|
|
|
|
|
|
Returns:
|
|
|
The state for the token.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ RuntimeError: when the state_cls is not specified in the token
|
|
|
"""
|
|
|
+ # Split the actual token from the fully qualified substate name.
|
|
|
+ client_token, _, state_path = token.partition("_")
|
|
|
+ if state_path:
|
|
|
+ # Get the State class associated with the given path.
|
|
|
+ state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
|
|
|
+ else:
|
|
|
+ raise RuntimeError(
|
|
|
+ "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Fetch the serialized substate from redis.
|
|
|
redis_state = await self.redis.get(token)
|
|
|
- if redis_state is None:
|
|
|
- await self.set_state(token, self.state())
|
|
|
- return await self.get_state(token)
|
|
|
- return cloudpickle.loads(redis_state)
|
|
|
+
|
|
|
+ if redis_state is not None:
|
|
|
+ # Deserialize the substate.
|
|
|
+ state = cloudpickle.loads(redis_state)
|
|
|
+
|
|
|
+ # Populate parent and substates if requested.
|
|
|
+ if parent_state is None:
|
|
|
+ # Retrieve the parent state from redis.
|
|
|
+ parent_state_name = state_path.rpartition(".")[0]
|
|
|
+ if parent_state_name:
|
|
|
+ parent_state_key = token.rpartition(".")[0]
|
|
|
+ parent_state = await self.get_state(
|
|
|
+ parent_state_key, top_level=False, get_substates=False
|
|
|
+ )
|
|
|
+ # Set up Bidirectional linkage between this state and its parent.
|
|
|
+ if parent_state is not None:
|
|
|
+ parent_state.substates[state.get_name()] = state
|
|
|
+ state.parent_state = parent_state
|
|
|
+ if get_substates:
|
|
|
+ # Retrieve all substates from redis.
|
|
|
+ for substate_cls in state_cls.get_substates():
|
|
|
+ substate_name = substate_cls.get_name()
|
|
|
+ substate_key = token + "." + substate_name
|
|
|
+ state.substates[substate_name] = await self.get_state(
|
|
|
+ substate_key, top_level=False, parent_state=state
|
|
|
+ )
|
|
|
+ # To retain compatibility with previous implementation, by default, we return
|
|
|
+ # the top-level state by chasing `parent_state` pointers up the tree.
|
|
|
+ if top_level:
|
|
|
+ while type(state) != self.state and state.parent_state is not None:
|
|
|
+ state = state.parent_state
|
|
|
+ return state
|
|
|
+
|
|
|
+ # Key didn't exist so we have to create a new entry for this token.
|
|
|
+ if parent_state is None:
|
|
|
+ parent_state_name = state_path.rpartition(".")[0]
|
|
|
+ if parent_state_name:
|
|
|
+ # Retrieve the parent state to populate event handlers onto this substate.
|
|
|
+ parent_state_key = client_token + "_" + parent_state_name
|
|
|
+ parent_state = await self.get_state(
|
|
|
+ parent_state_key, top_level=False, get_substates=False
|
|
|
+ )
|
|
|
+ # Persist the new state class to redis.
|
|
|
+ await self.set_state(
|
|
|
+ token,
|
|
|
+ state_cls(
|
|
|
+ parent_state=parent_state,
|
|
|
+ init_substates=False,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ # After creating the state key, recursively call `get_state` to populate substates.
|
|
|
+ return await self.get_state(
|
|
|
+ token,
|
|
|
+ top_level=top_level,
|
|
|
+ get_substates=get_substates,
|
|
|
+ parent_state=parent_state,
|
|
|
+ )
|
|
|
|
|
|
async def set_state(
|
|
|
- self, token: str, state: BaseState, lock_id: bytes | None = None
|
|
|
+ self,
|
|
|
+ token: str,
|
|
|
+ state: BaseState,
|
|
|
+ lock_id: bytes | None = None,
|
|
|
+ set_substates: bool = True,
|
|
|
+ set_parent_state: bool = True,
|
|
|
):
|
|
|
"""Set the state for a token.
|
|
|
|
|
@@ -1761,11 +1880,13 @@ class StateManagerRedis(StateManager):
|
|
|
token: The token to set the state for.
|
|
|
state: The state to set.
|
|
|
lock_id: If provided, the lock_key must be set to this value to set the state.
|
|
|
+ set_substates: If True, write substates to redis
|
|
|
+ set_parent_state: If True, write parent state to redis
|
|
|
|
|
|
Raises:
|
|
|
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
|
|
"""
|
|
|
- # check that we're holding the lock
|
|
|
+ # Check that we're holding the lock.
|
|
|
if (
|
|
|
lock_id is not None
|
|
|
and await self.redis.get(self._lock_key(token)) != lock_id
|
|
@@ -1775,6 +1896,27 @@ class StateManagerRedis(StateManager):
|
|
|
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
|
|
"or use `@rx.background` decorator for long-running tasks."
|
|
|
)
|
|
|
+ # Find the substate associated with the token.
|
|
|
+ state_path = token.partition("_")[2]
|
|
|
+ if state_path and state.get_full_name() != state_path:
|
|
|
+ state = state.get_substate(tuple(state_path.split(".")))
|
|
|
+ # Persist the parent state separately, if requested.
|
|
|
+ if state.parent_state is not None and set_parent_state:
|
|
|
+ parent_state_key = token.rpartition(".")[0]
|
|
|
+ await self.set_state(
|
|
|
+ parent_state_key,
|
|
|
+ state.parent_state,
|
|
|
+ lock_id=lock_id,
|
|
|
+ set_substates=False,
|
|
|
+ )
|
|
|
+ # Persist the substates separately, if requested.
|
|
|
+ if set_substates:
|
|
|
+ for substate_name, substate in state.substates.items():
|
|
|
+ substate_key = token + "." + substate_name
|
|
|
+ await self.set_state(
|
|
|
+ substate_key, substate, lock_id=lock_id, set_parent_state=False
|
|
|
+ )
|
|
|
+ # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
|
|
|
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
|
|
|
|
|
|
@contextlib.asynccontextmanager
|
|
@@ -1802,7 +1944,9 @@ class StateManagerRedis(StateManager):
|
|
|
Returns:
|
|
|
The redis lock key for the token.
|
|
|
"""
|
|
|
- return f"{token}_lock".encode()
|
|
|
+ # All substates share the same lock domain, so ignore any substate path suffix.
|
|
|
+ client_token = token.partition("_")[0]
|
|
|
+ return f"{client_token}_lock".encode()
|
|
|
|
|
|
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
|
|
"""Try to get a redis lock for a token.
|