|
@@ -8,7 +8,6 @@ import copy
|
|
|
import functools
|
|
|
import inspect
|
|
|
import json
|
|
|
-import os
|
|
|
import traceback
|
|
|
import urllib.parse
|
|
|
import uuid
|
|
@@ -45,6 +44,7 @@ from reflex.event import (
|
|
|
)
|
|
|
from reflex.utils import console, format, prerequisites, types
|
|
|
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
|
|
+from reflex.utils.exec import is_testing_env
|
|
|
from reflex.utils.serializers import SerializedType, serialize, serializer
|
|
|
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
|
|
|
|
@@ -151,9 +151,45 @@ RESERVED_BACKEND_VAR_NAMES = {
|
|
|
"_substate_var_dependencies",
|
|
|
"_always_dirty_computed_vars",
|
|
|
"_always_dirty_substates",
|
|
|
+ "_was_touched",
|
|
|
}
|
|
|
|
|
|
|
|
|
+def _substate_key(
|
|
|
+ token: str,
|
|
|
+ state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
|
|
|
+) -> str:
|
|
|
+ """Get the substate key.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token of the state.
|
|
|
+ state_cls_or_name: The state class/instance or name or sequence of name parts.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The substate key.
|
|
|
+ """
|
|
|
+ if isinstance(state_cls_or_name, BaseState) or (
|
|
|
+ isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
|
|
|
+ ):
|
|
|
+ state_cls_or_name = state_cls_or_name.get_full_name()
|
|
|
+ elif isinstance(state_cls_or_name, (list, tuple)):
|
|
|
+ state_cls_or_name = ".".join(state_cls_or_name)
|
|
|
+ return f"{token}_{state_cls_or_name}"
|
|
|
+
|
|
|
+
|
|
|
+def _split_substate_key(substate_key: str) -> tuple[str, str]:
|
|
|
+ """Split the substate key into token and state name.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ substate_key: The substate key.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tuple of token and state name.
|
|
|
+ """
|
|
|
+ token, _, state_name = substate_key.partition("_")
|
|
|
+ return token, state_name
|
|
|
+
|
|
|
+
|
|
|
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
"""The state of the app."""
|
|
|
|
|
@@ -214,29 +250,46 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# The router data for the current page
|
|
|
router: RouterData = RouterData()
|
|
|
|
|
|
+ # Whether the state has ever been touched since instantiation.
|
|
|
+ _was_touched: bool = False
|
|
|
+
|
|
|
def __init__(
|
|
|
self,
|
|
|
*args,
|
|
|
parent_state: BaseState | None = None,
|
|
|
init_substates: bool = True,
|
|
|
+ _reflex_internal_init: bool = False,
|
|
|
**kwargs,
|
|
|
):
|
|
|
"""Initialize the state.
|
|
|
|
|
|
+ DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead.
|
|
|
+
|
|
|
Args:
|
|
|
*args: The args to pass to the Pydantic init method.
|
|
|
parent_state: The parent state.
|
|
|
init_substates: Whether to initialize the substates in this instance.
|
|
|
+ _reflex_internal_init: A flag to indicate that the state is being initialized by the framework.
|
|
|
**kwargs: The kwargs to pass to the Pydantic init method.
|
|
|
|
|
|
+ Raises:
|
|
|
+ RuntimeError: If the state is instantiated directly by end user.
|
|
|
"""
|
|
|
+ if not _reflex_internal_init and not is_testing_env():
|
|
|
+ raise RuntimeError(
|
|
|
+ "State classes should not be instantiated directly in a Reflex app. "
|
|
|
+ "See https://reflex.dev/docs/state for further information."
|
|
|
+ )
|
|
|
kwargs["parent_state"] = parent_state
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
# Setup the substates (for memory state manager only).
|
|
|
if init_substates:
|
|
|
for substate in self.get_substates():
|
|
|
- self.substates[substate.get_name()] = substate(parent_state=self)
|
|
|
+ self.substates[substate.get_name()] = substate(
|
|
|
+ parent_state=self,
|
|
|
+ _reflex_internal_init=True,
|
|
|
+ )
|
|
|
# Convert the event handlers to functions.
|
|
|
self._init_event_handlers()
|
|
|
|
|
@@ -287,7 +340,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
Raises:
|
|
|
ValueError: If a substate class shadows another.
|
|
|
"""
|
|
|
- is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
|
|
|
super().__init_subclass__(**kwargs)
|
|
|
# Event handlers should not shadow builtin state methods.
|
|
|
cls._check_overridden_methods()
|
|
@@ -295,6 +347,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# Reset subclass tracking for this class.
|
|
|
cls.class_subclasses = set()
|
|
|
|
|
|
+ # Reset dirty substate tracking for this class.
|
|
|
+ cls._always_dirty_substates = set()
|
|
|
+
|
|
|
# Get the parent vars.
|
|
|
parent_state = cls.get_parent_state()
|
|
|
if parent_state is not None:
|
|
@@ -303,7 +358,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
|
|
|
# Check if another substate class with the same name has already been defined.
|
|
|
if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
|
|
|
- if is_testing_env:
|
|
|
+ if is_testing_env():
|
|
|
# Clear existing subclass with same name when app is reloaded via
|
|
|
# utils.prerequisites.get_app(reload=True)
|
|
|
parent_state.class_subclasses = set(
|
|
@@ -325,6 +380,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
name: value
|
|
|
for name, value in cls.__dict__.items()
|
|
|
if types.is_backend_variable(name, cls)
|
|
|
+ and name not in RESERVED_BACKEND_VAR_NAMES
|
|
|
and name not in cls.inherited_backend_vars
|
|
|
and not isinstance(value, FunctionType)
|
|
|
}
|
|
@@ -484,7 +540,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
)
|
|
|
|
|
|
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
|
|
|
- cls._always_dirty_substates = set()
|
|
|
if cls._always_dirty_computed_vars:
|
|
|
# Tell parent classes that this substate has always dirty computed vars
|
|
|
state_name = cls.get_name()
|
|
@@ -923,8 +978,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
**super().__getattribute__("inherited_vars"),
|
|
|
**super().__getattribute__("inherited_backend_vars"),
|
|
|
}
|
|
|
- if name in inherited_vars:
|
|
|
- return getattr(super().__getattribute__("parent_state"), name)
|
|
|
+
|
|
|
+ # For now, handle router_data updates as a special case.
|
|
|
+ if name in inherited_vars or name == constants.ROUTER_DATA:
|
|
|
+ parent_state = super().__getattribute__("parent_state")
|
|
|
+ if parent_state is not None:
|
|
|
+ return getattr(parent_state, name)
|
|
|
|
|
|
backend_vars = super().__getattribute__("_backend_vars")
|
|
|
if name in backend_vars:
|
|
@@ -980,9 +1039,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
if name == constants.ROUTER_DATA:
|
|
|
self.dirty_vars.add(name)
|
|
|
self._mark_dirty()
|
|
|
- # propagate router_data updates down the state tree
|
|
|
- for substate in self.substates.values():
|
|
|
- setattr(substate, name, value)
|
|
|
|
|
|
def reset(self):
|
|
|
"""Reset all the base vars to their default values."""
|
|
@@ -1036,6 +1092,170 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
raise ValueError(f"Invalid path: {path}")
|
|
|
return self.substates[path[0]].get_substate(path[1:])
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
|
|
|
+ """Find the name of the nearest common ancestor shared by this and the other state.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ other: The other state.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Full name of the nearest common ancestor.
|
|
|
+ """
|
|
|
+ common_ancestor_parts = []
|
|
|
+ for part1, part2 in zip(
|
|
|
+ cls.get_full_name().split("."),
|
|
|
+ other.get_full_name().split("."),
|
|
|
+ ):
|
|
|
+ if part1 != part2:
|
|
|
+ break
|
|
|
+ common_ancestor_parts.append(part1)
|
|
|
+ return ".".join(common_ancestor_parts)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _determine_missing_parent_states(
|
|
|
+ cls, target_state_cls: Type[BaseState]
|
|
|
+ ) -> tuple[str, list[str]]:
|
|
|
+ """Determine the missing parent states between the target_state_cls and common ancestor of this state.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ target_state_cls: The class of the state to find missing parent states for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The name of the common ancestor and the list of missing parent states.
|
|
|
+ """
|
|
|
+ common_ancestor_name = cls._get_common_ancestor(target_state_cls)
|
|
|
+ common_ancestor_parts = common_ancestor_name.split(".")
|
|
|
+ target_state_parts = tuple(target_state_cls.get_full_name().split("."))
|
|
|
+ relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
|
|
|
+
|
|
|
+ # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
|
|
|
+ fetch_parent_states = [common_ancestor_name]
|
|
|
+ for ix, relative_parent_state_name in enumerate(relative_target_state_parts):
|
|
|
+ fetch_parent_states.append(
|
|
|
+ ".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name])
|
|
|
+ )
|
|
|
+
|
|
|
+ return common_ancestor_name, fetch_parent_states[1:-1]
|
|
|
+
|
|
|
+ def _get_parent_states(self) -> list[tuple[str, BaseState]]:
|
|
|
+ """Get all parent state instances up to the root of the state tree.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ A list of tuples containing the name and the instance of each parent state.
|
|
|
+ """
|
|
|
+ parent_states_with_name = []
|
|
|
+ parent_state = self
|
|
|
+ while parent_state.parent_state is not None:
|
|
|
+ parent_state = parent_state.parent_state
|
|
|
+ parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
|
|
+ return parent_states_with_name
|
|
|
+
|
|
|
+ async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
|
|
|
+ """Populate substates in the tree between the target_state_cls and common ancestor of this state.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ target_state_cls: The class of the state to populate parent states for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The parent state instance of target_state_cls.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ RuntimeError: If redis is not used in this backend process.
|
|
|
+ """
|
|
|
+ state_manager = get_state_manager()
|
|
|
+ if not isinstance(state_manager, StateManagerRedis):
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
|
|
|
+ "(All states should already be available -- this is likely a bug).",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Find the missing parent states up to the common ancestor.
|
|
|
+ (
|
|
|
+ common_ancestor_name,
|
|
|
+ missing_parent_states,
|
|
|
+ ) = self._determine_missing_parent_states(target_state_cls)
|
|
|
+
|
|
|
+ # Fetch all missing parent states and link them up to the common ancestor.
|
|
|
+ parent_states_by_name = dict(self._get_parent_states())
|
|
|
+ parent_state = parent_states_by_name[common_ancestor_name]
|
|
|
+ for parent_state_name in missing_parent_states:
|
|
|
+ parent_state = await state_manager.get_state(
|
|
|
+ token=_substate_key(
|
|
|
+ self.router.session.client_token, parent_state_name
|
|
|
+ ),
|
|
|
+ top_level=False,
|
|
|
+ get_substates=False,
|
|
|
+ parent_state=parent_state,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Return the direct parent of target_state_cls for subsequent linking.
|
|
|
+ return parent_state
|
|
|
+
|
|
|
+ def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
|
|
|
+ """Get a state instance from the cache.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state_cls: The class of the state.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The instance of state_cls associated with this state's client_token.
|
|
|
+ """
|
|
|
+ if self.parent_state is None:
|
|
|
+ root_state = self
|
|
|
+ else:
|
|
|
+ root_state = self._get_parent_states()[-1][1]
|
|
|
+ return root_state.get_substate(state_cls.get_full_name().split("."))
|
|
|
+
|
|
|
+ async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
|
|
|
+ """Get a state instance from redis.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state_cls: The class of the state.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The instance of state_cls associated with this state's client_token.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ RuntimeError: If redis is not used in this backend process.
|
|
|
+ """
|
|
|
+ # Fetch all missing parent states from redis.
|
|
|
+ parent_state_of_state_cls = await self._populate_parent_states(state_cls)
|
|
|
+
|
|
|
+ # Then get the target state and all its substates.
|
|
|
+ state_manager = get_state_manager()
|
|
|
+ if not isinstance(state_manager, StateManagerRedis):
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
|
|
+ "(All states should already be available -- this is likely a bug).",
|
|
|
+ )
|
|
|
+ return await state_manager.get_state(
|
|
|
+ token=_substate_key(self.router.session.client_token, state_cls),
|
|
|
+ top_level=False,
|
|
|
+ get_substates=True,
|
|
|
+ parent_state=parent_state_of_state_cls,
|
|
|
+ )
|
|
|
+
|
|
|
+ async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
|
|
|
+ """Get an instance of the state associated with this token.
|
|
|
+
|
|
|
+ Allows for arbitrary access to sibling states from within an event handler.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state_cls: The class of the state.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The instance of state_cls associated with this state's client_token.
|
|
|
+ """
|
|
|
+ # Fast case - if this state instance is already cached, get_substate from root state.
|
|
|
+ try:
|
|
|
+ return self._get_state_from_cache(state_cls)
|
|
|
+ except ValueError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # Slow case - fetch missing parent states from redis.
|
|
|
+ return await self._get_state_from_redis(state_cls)
|
|
|
+
|
|
|
def _get_event_handler(
|
|
|
self, event: Event
|
|
|
) -> tuple[BaseState | StateProxy, EventHandler]:
|
|
@@ -1238,6 +1458,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
for cvar in self._computed_var_dependencies[dirty_var]
|
|
|
)
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
|
|
+ """Determine substates which could be affected by dirty vars in this state.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Set of State classes that may need to be fetched to recalc computed vars.
|
|
|
+ """
|
|
|
+ # _always_dirty_substates need to be fetched to recalc computed vars.
|
|
|
+ fetch_substates = set(
|
|
|
+ cls.get_class_substate(tuple(substate_name.split(".")))
|
|
|
+ for substate_name in cls._always_dirty_substates
|
|
|
+ )
|
|
|
+ # Substates with cached vars also need to be fetched.
|
|
|
+ for dependent_substates in cls._substate_var_dependencies.values():
|
|
|
+ fetch_substates.update(
|
|
|
+ set(
|
|
|
+ cls.get_class_substate(tuple(substate_name.split(".")))
|
|
|
+ for substate_name in dependent_substates
|
|
|
+ )
|
|
|
+ )
|
|
|
+ return fetch_substates
|
|
|
+
|
|
|
def get_delta(self) -> Delta:
|
|
|
"""Get the delta for the state.
|
|
|
|
|
@@ -1269,8 +1511,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# Recursively find the substate deltas.
|
|
|
substates = self.substates
|
|
|
for substate in self.dirty_substates.union(self._always_dirty_substates):
|
|
|
- if substate not in substates:
|
|
|
- continue # substate not loaded at this time, no delta
|
|
|
delta.update(substates[substate].get_delta())
|
|
|
|
|
|
# Format the delta.
|
|
@@ -1292,20 +1532,45 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
# have to mark computed vars dirty to allow access to newly computed
|
|
|
# values within the same ComputedVar function
|
|
|
self._mark_dirty_computed_vars()
|
|
|
+ self._mark_dirty_substates()
|
|
|
|
|
|
- # Propagate dirty var / computed var status into substates
|
|
|
+ def _mark_dirty_substates(self):
|
|
|
+ """Propagate dirty var / computed var status into substates."""
|
|
|
substates = self.substates
|
|
|
for var in self.dirty_vars:
|
|
|
for substate_name in self._substate_var_dependencies[var]:
|
|
|
self.dirty_substates.add(substate_name)
|
|
|
- if substate_name not in substates:
|
|
|
- continue
|
|
|
substate = substates[substate_name]
|
|
|
substate.dirty_vars.add(var)
|
|
|
substate._mark_dirty()
|
|
|
|
|
|
+ def _update_was_touched(self):
|
|
|
+ """Update the _was_touched flag based on dirty_vars."""
|
|
|
+ if self.dirty_vars and not self._was_touched:
|
|
|
+ for var in self.dirty_vars:
|
|
|
+ if var in self.base_vars or var in self._backend_vars:
|
|
|
+ self._was_touched = True
|
|
|
+ break
|
|
|
+
|
|
|
+ def _get_was_touched(self) -> bool:
|
|
|
+ """Check current dirty_vars and flag to determine if state instance was modified.
|
|
|
+
|
|
|
+ If any dirty vars belong to this state, mark _was_touched.
|
|
|
+
|
|
|
+ This flag determines whether this state instance should be persisted to redis.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Whether this state instance was ever modified.
|
|
|
+ """
|
|
|
+ # Ensure the flag is up to date based on the current dirty_vars
|
|
|
+ self._update_was_touched()
|
|
|
+ return self._was_touched
|
|
|
+
|
|
|
def _clean(self):
|
|
|
"""Reset the dirty vars."""
|
|
|
+ # Update touched status before cleaning dirty_vars.
|
|
|
+ self._update_was_touched()
|
|
|
+
|
|
|
# Recursively clean the substates.
|
|
|
for substate in self.dirty_substates:
|
|
|
if substate not in self.substates:
|
|
@@ -1422,6 +1687,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
state["__dict__"] = state["__dict__"].copy()
|
|
|
state["__dict__"]["parent_state"] = None
|
|
|
state["__dict__"]["substates"] = {}
|
|
|
+ state["__dict__"].pop("_was_touched", None)
|
|
|
return state
|
|
|
|
|
|
|
|
@@ -1431,6 +1697,35 @@ class State(BaseState):
|
|
|
# The hydrated bool.
|
|
|
is_hydrated: bool = False
|
|
|
|
|
|
+
|
|
|
+class UpdateVarsInternalState(State):
|
|
|
+ """Substate for handling internal state var updates."""
|
|
|
+
|
|
|
+ async def update_vars_internal(self, vars: dict[str, Any]) -> None:
|
|
|
+ """Apply updates to fully qualified state vars.
|
|
|
+
|
|
|
+ The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
|
|
|
+ and each value will be set on the appropriate substate instance.
|
|
|
+
|
|
|
+ This function is primarily used to apply cookie and local storage
|
|
|
+ updates from the frontend to the appropriate substate.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ vars: The fully qualified vars and values to update.
|
|
|
+ """
|
|
|
+ for var, value in vars.items():
|
|
|
+ state_name, _, var_name = var.rpartition(".")
|
|
|
+ var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
|
|
|
+ var_state = await self.get_state(var_state_cls)
|
|
|
+ setattr(var_state, var_name, value)
|
|
|
+
|
|
|
+
|
|
|
+class OnLoadInternalState(State):
|
|
|
+ """Substate for handling on_load event enumeration.
|
|
|
+
|
|
|
+ This is a separate substate to avoid deserializing the entire state tree for every page navigation.
|
|
|
+ """
|
|
|
+
|
|
|
def on_load_internal(self) -> list[Event | EventSpec] | None:
|
|
|
"""Queue on_load handlers for the current page.
|
|
|
|
|
@@ -1442,6 +1737,9 @@ class State(BaseState):
|
|
|
load_events = app.get_load_events(self.router.page.path)
|
|
|
if not load_events and self.is_hydrated:
|
|
|
return # Fast path for page-to-page navigation
|
|
|
+ if not load_events:
|
|
|
+ self.is_hydrated = True
|
|
|
+ return # Fast path for initial hydrate with no on_load events defined.
|
|
|
self.is_hydrated = False
|
|
|
return [
|
|
|
*fix_events(
|
|
@@ -1449,26 +1747,9 @@ class State(BaseState):
|
|
|
self.router.session.client_token,
|
|
|
router_data=self.router_data,
|
|
|
),
|
|
|
- type(self).set_is_hydrated(True), # type: ignore
|
|
|
+ State.set_is_hydrated(True), # type: ignore
|
|
|
]
|
|
|
|
|
|
- def update_vars_internal(self, vars: dict[str, Any]) -> None:
|
|
|
- """Apply updates to fully qualified state vars.
|
|
|
-
|
|
|
- The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
|
|
|
- and each value will be set on the appropriate substate instance.
|
|
|
-
|
|
|
- This function is primarily used to apply cookie and local storage
|
|
|
- updates from the frontend to the appropriate substate.
|
|
|
-
|
|
|
- Args:
|
|
|
- vars: The fully qualified vars and values to update.
|
|
|
- """
|
|
|
- for var, value in vars.items():
|
|
|
- state_name, _, var_name = var.rpartition(".")
|
|
|
- var_state = self.get_substate(state_name.split("."))
|
|
|
- setattr(var_state, var_name, value)
|
|
|
-
|
|
|
|
|
|
class StateProxy(wrapt.ObjectProxy):
|
|
|
"""Proxy of a state instance to control mutability of vars for a background task.
|
|
@@ -1522,9 +1803,10 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
This StateProxy instance in mutable mode.
|
|
|
"""
|
|
|
self._self_actx = self._self_app.modify_state(
|
|
|
- self.__wrapped__.router.session.client_token
|
|
|
- + "_"
|
|
|
- + ".".join(self._self_substate_path)
|
|
|
+ token=_substate_key(
|
|
|
+ self.__wrapped__.router.session.client_token,
|
|
|
+ self._self_substate_path,
|
|
|
+ )
|
|
|
)
|
|
|
mutable_state = await self._self_actx.__aenter__()
|
|
|
super().__setattr__(
|
|
@@ -1574,7 +1856,15 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
|
|
|
Returns:
|
|
|
The value of the attribute.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ImmutableStateError: If the state is not in mutable mode.
|
|
|
"""
|
|
|
+ if name in ["substates", "parent_state"] and not self._self_mutable:
|
|
|
+ raise ImmutableStateError(
|
|
|
+ "Background task StateProxy is immutable outside of a context "
|
|
|
+ "manager. Use `async with self` to modify state."
|
|
|
+ )
|
|
|
value = super().__getattr__(name)
|
|
|
if not name.startswith("_self_") and isinstance(value, MutableProxy):
|
|
|
# ensure mutations to these containers are blocked unless proxy is _mutable
|
|
@@ -1622,6 +1912,60 @@ class StateProxy(wrapt.ObjectProxy):
|
|
|
"manager. Use `async with self` to modify state."
|
|
|
)
|
|
|
|
|
|
+ def get_substate(self, path: Sequence[str]) -> BaseState:
|
|
|
+ """Only allow substate access with lock held.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ path: The path to the substate.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The substate.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ImmutableStateError: If the state is not in mutable mode.
|
|
|
+ """
|
|
|
+ if not self._self_mutable:
|
|
|
+ raise ImmutableStateError(
|
|
|
+ "Background task StateProxy is immutable outside of a context "
|
|
|
+ "manager. Use `async with self` to modify state."
|
|
|
+ )
|
|
|
+ return self.__wrapped__.get_substate(path)
|
|
|
+
|
|
|
+ async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
|
|
|
+ """Get an instance of the state associated with this token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state_cls: The class of the state.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ImmutableStateError: If the state is not in mutable mode.
|
|
|
+ """
|
|
|
+ if not self._self_mutable:
|
|
|
+ raise ImmutableStateError(
|
|
|
+ "Background task StateProxy is immutable outside of a context "
|
|
|
+ "manager. Use `async with self` to modify state."
|
|
|
+ )
|
|
|
+ return await self.__wrapped__.get_state(state_cls)
|
|
|
+
|
|
|
+ def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
|
|
+ """Temporarily allow mutability to access parent_state.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ *args: The args to pass to the underlying state instance.
|
|
|
+ **kwargs: The kwargs to pass to the underlying state instance.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state update.
|
|
|
+ """
|
|
|
+ self._self_mutable = True
|
|
|
+ try:
|
|
|
+ return self.__wrapped__._as_state_update(*args, **kwargs)
|
|
|
+ finally:
|
|
|
+ self._self_mutable = False
|
|
|
+
|
|
|
|
|
|
class StateUpdate(Base):
|
|
|
"""A state update sent to the frontend."""
|
|
@@ -1722,9 +2066,9 @@ class StateManagerMemory(StateManager):
|
|
|
The state for the token.
|
|
|
"""
|
|
|
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
- token = token.partition("_")[0]
|
|
|
+ token = _split_substate_key(token)[0]
|
|
|
if token not in self.states:
|
|
|
- self.states[token] = self.state()
|
|
|
+ self.states[token] = self.state(_reflex_internal_init=True)
|
|
|
return self.states[token]
|
|
|
|
|
|
async def set_state(self, token: str, state: BaseState):
|
|
@@ -1747,7 +2091,7 @@ class StateManagerMemory(StateManager):
|
|
|
The state for the token.
|
|
|
"""
|
|
|
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
- token = token.partition("_")[0]
|
|
|
+ token = _split_substate_key(token)[0]
|
|
|
if token not in self._states_locks:
|
|
|
async with self._state_manager_lock:
|
|
|
if token not in self._states_locks:
|
|
@@ -1787,6 +2131,81 @@ class StateManagerRedis(StateManager):
|
|
|
b"evicted",
|
|
|
}
|
|
|
|
|
|
+ def _get_root_state(self, state: BaseState) -> BaseState:
|
|
|
+ """Chase parent_state pointers to find an instance of the top-level state.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state: The state to start from.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ An instance of the top-level state (self.state).
|
|
|
+ """
|
|
|
+ while type(state) != self.state and state.parent_state is not None:
|
|
|
+ state = state.parent_state
|
|
|
+ return state
|
|
|
+
|
|
|
+ async def _get_parent_state(self, token: str) -> BaseState | None:
|
|
|
+ """Get the parent state for the state requested in the token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the state for (_substate_key).
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The parent state for the state requested by the token or None if there is no such parent.
|
|
|
+ """
|
|
|
+ parent_state = None
|
|
|
+ client_token, state_path = _split_substate_key(token)
|
|
|
+ parent_state_name = state_path.rpartition(".")[0]
|
|
|
+ if parent_state_name:
|
|
|
+ # Retrieve the parent state to populate event handlers onto this substate.
|
|
|
+ parent_state = await self.get_state(
|
|
|
+ token=_substate_key(client_token, parent_state_name),
|
|
|
+ top_level=False,
|
|
|
+ get_substates=False,
|
|
|
+ )
|
|
|
+ return parent_state
|
|
|
+
|
|
|
+ async def _populate_substates(
|
|
|
+ self,
|
|
|
+ token: str,
|
|
|
+ state: BaseState,
|
|
|
+ all_substates: bool = False,
|
|
|
+ ):
|
|
|
+ """Fetch and link substates for the given state instance.
|
|
|
+
|
|
|
+ There is no return value; the side-effect is that `state` will have `substates` populated,
|
|
|
+ and each substate will have its `parent_state` set to `state`.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the state for.
|
|
|
+ state: The state instance to populate substates for.
|
|
|
+ all_substates: Whether to fetch all substates or just required substates.
|
|
|
+ """
|
|
|
+ client_token, _ = _split_substate_key(token)
|
|
|
+
|
|
|
+ if all_substates:
|
|
|
+ # All substates are requested.
|
|
|
+ fetch_substates = state.get_substates()
|
|
|
+ else:
|
|
|
+ # Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
|
|
+ fetch_substates = state._potentially_dirty_substates()
|
|
|
+
|
|
|
+ tasks = {}
|
|
|
+ # Retrieve the necessary substates from redis.
|
|
|
+ for substate_cls in fetch_substates:
|
|
|
+ substate_name = substate_cls.get_name()
|
|
|
+ tasks[substate_name] = asyncio.create_task(
|
|
|
+ self.get_state(
|
|
|
+ token=_substate_key(client_token, substate_cls),
|
|
|
+ top_level=False,
|
|
|
+ get_substates=all_substates,
|
|
|
+ parent_state=state,
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ for substate_name, substate_task in tasks.items():
|
|
|
+ state.substates[substate_name] = await substate_task
|
|
|
+
|
|
|
async def get_state(
|
|
|
self,
|
|
|
token: str,
|
|
@@ -1798,8 +2217,8 @@ class StateManagerRedis(StateManager):
|
|
|
|
|
|
Args:
|
|
|
token: The token to get the state for.
|
|
|
- top_level: If true, return an instance of the top-level state.
|
|
|
- get_substates: If true, also retrieve substates
|
|
|
+ top_level: If true, return an instance of the top-level state (self.state).
|
|
|
+ get_substates: If true, also retrieve substates.
|
|
|
parent_state: If provided, use this parent_state instead of getting it from redis.
|
|
|
|
|
|
Returns:
|
|
@@ -1809,7 +2228,7 @@ class StateManagerRedis(StateManager):
|
|
|
RuntimeError: when the state_cls is not specified in the token
|
|
|
"""
|
|
|
# Split the actual token from the fully qualified substate name.
|
|
|
- client_token, _, state_path = token.partition("_")
|
|
|
+ _, state_path = _split_substate_key(token)
|
|
|
if state_path:
|
|
|
# Get the State class associated with the given path.
|
|
|
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
|
|
@@ -1825,66 +2244,49 @@ class StateManagerRedis(StateManager):
|
|
|
# Deserialize the substate.
|
|
|
state = cloudpickle.loads(redis_state)
|
|
|
|
|
|
- # Populate parent and substates if requested.
|
|
|
+ # Populate parent state if missing and requested.
|
|
|
if parent_state is None:
|
|
|
- # Retrieve the parent state from redis.
|
|
|
- parent_state_name = state_path.rpartition(".")[0]
|
|
|
- if parent_state_name:
|
|
|
- parent_state_key = token.rpartition(".")[0]
|
|
|
- parent_state = await self.get_state(
|
|
|
- parent_state_key, top_level=False, get_substates=False
|
|
|
- )
|
|
|
+ parent_state = await self._get_parent_state(token)
|
|
|
# Set up Bidirectional linkage between this state and its parent.
|
|
|
if parent_state is not None:
|
|
|
parent_state.substates[state.get_name()] = state
|
|
|
state.parent_state = parent_state
|
|
|
- if get_substates:
|
|
|
- # Retrieve all substates from redis.
|
|
|
- for substate_cls in state_cls.get_substates():
|
|
|
- substate_name = substate_cls.get_name()
|
|
|
- substate_key = token + "." + substate_name
|
|
|
- state.substates[substate_name] = await self.get_state(
|
|
|
- substate_key, top_level=False, parent_state=state
|
|
|
- )
|
|
|
+ # Populate substates if requested.
|
|
|
+ await self._populate_substates(token, state, all_substates=get_substates)
|
|
|
+
|
|
|
# To retain compatibility with previous implementation, by default, we return
|
|
|
# the top-level state by chasing `parent_state` pointers up the tree.
|
|
|
if top_level:
|
|
|
- while type(state) != self.state and state.parent_state is not None:
|
|
|
- state = state.parent_state
|
|
|
+ return self._get_root_state(state)
|
|
|
return state
|
|
|
|
|
|
- # Key didn't exist so we have to create a new entry for this token.
|
|
|
+ # TODO: dedupe the following logic with the above block
|
|
|
+ # Key didn't exist so we have to create a new instance for this token.
|
|
|
if parent_state is None:
|
|
|
- parent_state_name = state_path.rpartition(".")[0]
|
|
|
- if parent_state_name:
|
|
|
- # Retrieve the parent state to populate event handlers onto this substate.
|
|
|
- parent_state_key = client_token + "_" + parent_state_name
|
|
|
- parent_state = await self.get_state(
|
|
|
- parent_state_key, top_level=False, get_substates=False
|
|
|
- )
|
|
|
- # Persist the new state class to redis.
|
|
|
- await self.set_state(
|
|
|
- token,
|
|
|
- state_cls(
|
|
|
- parent_state=parent_state,
|
|
|
- init_substates=False,
|
|
|
- ),
|
|
|
- )
|
|
|
- # After creating the state key, recursively call `get_state` to populate substates.
|
|
|
- return await self.get_state(
|
|
|
- token,
|
|
|
- top_level=top_level,
|
|
|
- get_substates=get_substates,
|
|
|
+ parent_state = await self._get_parent_state(token)
|
|
|
+ # Instantiate the new state class (but don't persist it yet).
|
|
|
+ state = state_cls(
|
|
|
parent_state=parent_state,
|
|
|
+ init_substates=False,
|
|
|
+ _reflex_internal_init=True,
|
|
|
)
|
|
|
+ # Set up Bidirectional linkage between this state and its parent.
|
|
|
+ if parent_state is not None:
|
|
|
+ parent_state.substates[state.get_name()] = state
|
|
|
+ state.parent_state = parent_state
|
|
|
+ # Populate substates for the newly created state.
|
|
|
+ await self._populate_substates(token, state, all_substates=get_substates)
|
|
|
+ # To retain compatibility with previous implementation, by default, we return
|
|
|
+ # the top-level state by chasing `parent_state` pointers up the tree.
|
|
|
+ if top_level:
|
|
|
+ return self._get_root_state(state)
|
|
|
+ return state
|
|
|
|
|
|
async def set_state(
|
|
|
self,
|
|
|
token: str,
|
|
|
state: BaseState,
|
|
|
lock_id: bytes | None = None,
|
|
|
- set_substates: bool = True,
|
|
|
- set_parent_state: bool = True,
|
|
|
):
|
|
|
"""Set the state for a token.
|
|
|
|
|
@@ -1892,11 +2294,10 @@ class StateManagerRedis(StateManager):
|
|
|
token: The token to set the state for.
|
|
|
state: The state to set.
|
|
|
lock_id: If provided, the lock_key must be set to this value to set the state.
|
|
|
- set_substates: If True, write substates to redis
|
|
|
- set_parent_state: If True, write parent state to redis
|
|
|
|
|
|
Raises:
|
|
|
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
|
|
+ RuntimeError: If the state instance doesn't match the state name in the token.
|
|
|
"""
|
|
|
# Check that we're holding the lock.
|
|
|
if (
|
|
@@ -1908,28 +2309,36 @@ class StateManagerRedis(StateManager):
|
|
|
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
|
|
"or use `@rx.background` decorator for long-running tasks."
|
|
|
)
|
|
|
- # Find the substate associated with the token.
|
|
|
- state_path = token.partition("_")[2]
|
|
|
- if state_path and state.get_full_name() != state_path:
|
|
|
- state = state.get_substate(tuple(state_path.split(".")))
|
|
|
- # Persist the parent state separately, if requested.
|
|
|
- if state.parent_state is not None and set_parent_state:
|
|
|
- parent_state_key = token.rpartition(".")[0]
|
|
|
- await self.set_state(
|
|
|
- parent_state_key,
|
|
|
- state.parent_state,
|
|
|
- lock_id=lock_id,
|
|
|
- set_substates=False,
|
|
|
+ client_token, substate_name = _split_substate_key(token)
|
|
|
+ # If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
|
|
+ if state.parent_state is not None and state.get_full_name() != substate_name:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
|
|
|
)
|
|
|
- # Persist the substates separately, if requested.
|
|
|
- if set_substates:
|
|
|
- for substate_name, substate in state.substates.items():
|
|
|
- substate_key = token + "." + substate_name
|
|
|
- await self.set_state(
|
|
|
- substate_key, substate, lock_id=lock_id, set_parent_state=False
|
|
|
+
|
|
|
+ # Recursively set_state on all known substates.
|
|
|
+ tasks = []
|
|
|
+ for substate in state.substates.values():
|
|
|
+ tasks.append(
|
|
|
+ asyncio.create_task(
|
|
|
+ self.set_state(
|
|
|
+ token=_substate_key(client_token, substate),
|
|
|
+ state=substate,
|
|
|
+ lock_id=lock_id,
|
|
|
+ )
|
|
|
)
|
|
|
+ )
|
|
|
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
|
|
|
- await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
|
|
|
+ if state._get_was_touched():
|
|
|
+ await self.redis.set(
|
|
|
+ _substate_key(client_token, state),
|
|
|
+ cloudpickle.dumps(state),
|
|
|
+ ex=self.token_expiration,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Wait for substates to be persisted.
|
|
|
+ for t in tasks:
|
|
|
+ await t
|
|
|
|
|
|
@contextlib.asynccontextmanager
|
|
|
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
@@ -1957,7 +2366,7 @@ class StateManagerRedis(StateManager):
|
|
|
The redis lock key for the token.
|
|
|
"""
|
|
|
# All substates share the same lock domain, so ignore any substate path suffix.
|
|
|
- client_token = token.partition("_")[0]
|
|
|
+ client_token = _split_substate_key(token)[0]
|
|
|
return f"{client_token}_lock".encode()
|
|
|
|
|
|
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
|
@@ -2052,6 +2461,16 @@ class StateManagerRedis(StateManager):
|
|
|
await self.redis.close(close_connection_pool=True)
|
|
|
|
|
|
|
|
|
+def get_state_manager() -> StateManager:
|
|
|
+ """Get the state manager for the app that is currently running.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state manager.
|
|
|
+ """
|
|
|
+ app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
|
|
|
+ return app.state_manager
|
|
|
+
|
|
|
+
|
|
|
class ClientStorageBase:
|
|
|
"""Base class for client-side storage."""
|
|
|
|