|
@@ -0,0 +1,858 @@
|
|
|
+"""State manager for managing client states."""
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import contextlib
|
|
|
+import dataclasses
|
|
|
+import functools
|
|
|
+import time
|
|
|
+import uuid
|
|
|
+from abc import ABC, abstractmethod
|
|
|
+from collections.abc import AsyncIterator
|
|
|
+from hashlib import md5
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+from redis import ResponseError
|
|
|
+from redis.asyncio import Redis
|
|
|
+from redis.asyncio.client import PubSub
|
|
|
+from typing_extensions import override
|
|
|
+
|
|
|
+from reflex import constants
|
|
|
+from reflex.config import environment, get_config
|
|
|
+from reflex.state import BaseState, _split_substate_key, _substate_key
|
|
|
+from reflex.utils import console, path_ops, prerequisites
|
|
|
+from reflex.utils.exceptions import (
|
|
|
+ InvalidLockWarningThresholdError,
|
|
|
+ InvalidStateManagerModeError,
|
|
|
+ LockExpiredError,
|
|
|
+ StateSchemaMismatchError,
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+@dataclasses.dataclass
|
|
|
+class StateManager(ABC):
|
|
|
+ """A class to manage many client states."""
|
|
|
+
|
|
|
+ # The state class to use.
|
|
|
+ state: type[BaseState]
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def create(cls, state: type[BaseState]):
|
|
|
+ """Create a new state manager.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state: The state class to use.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ InvalidStateManagerModeError: If the state manager mode is invalid.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state manager (either disk, memory or redis).
|
|
|
+ """
|
|
|
+ config = get_config()
|
|
|
+ if prerequisites.parse_redis_url() is not None:
|
|
|
+ config.state_manager_mode = constants.StateManagerMode.REDIS
|
|
|
+ if config.state_manager_mode == constants.StateManagerMode.MEMORY:
|
|
|
+ return StateManagerMemory(state=state)
|
|
|
+ if config.state_manager_mode == constants.StateManagerMode.DISK:
|
|
|
+ return StateManagerDisk(state=state)
|
|
|
+ if config.state_manager_mode == constants.StateManagerMode.REDIS:
|
|
|
+ redis = prerequisites.get_redis()
|
|
|
+ if redis is not None:
|
|
|
+ # make sure expiration values are obtained only from the config object on creation
|
|
|
+ return StateManagerRedis(
|
|
|
+ state=state,
|
|
|
+ redis=redis,
|
|
|
+ token_expiration=config.redis_token_expiration,
|
|
|
+ lock_expiration=config.redis_lock_expiration,
|
|
|
+ lock_warning_threshold=config.redis_lock_warning_threshold,
|
|
|
+ )
|
|
|
+ raise InvalidStateManagerModeError(
|
|
|
+ f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
|
|
+ )
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ async def get_state(self, token: str) -> BaseState:
|
|
|
+ """Get the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the state for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ pass
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ async def set_state(self, token: str, state: BaseState):
|
|
|
+ """Set the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to set the state for.
|
|
|
+ state: The state to set.
|
|
|
+ """
|
|
|
+ pass
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ @contextlib.asynccontextmanager
|
|
|
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
+ """Modify the state for a token while holding exclusive lock.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to modify the state for.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ yield self.state()
|
|
|
+
|
|
|
+
|
|
|
+@dataclasses.dataclass
|
|
|
+class StateManagerMemory(StateManager):
|
|
|
+ """A state manager that stores states in memory."""
|
|
|
+
|
|
|
+ # The mapping of client ids to states.
|
|
|
+ states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
|
|
|
+
|
|
|
+ # The mutex ensures the dict of mutexes is updated exclusively
|
|
|
+ _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
|
|
|
+
|
|
|
+ # The dict of mutexes for each client
|
|
|
+ _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
|
|
|
+ default_factory=dict, init=False
|
|
|
+ )
|
|
|
+
|
|
|
+ @override
|
|
|
+ async def get_state(self, token: str) -> BaseState:
|
|
|
+ """Get the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the state for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
+ token = _split_substate_key(token)[0]
|
|
|
+ if token not in self.states:
|
|
|
+ self.states[token] = self.state(_reflex_internal_init=True)
|
|
|
+ return self.states[token]
|
|
|
+
|
|
|
+ @override
|
|
|
+ async def set_state(self, token: str, state: BaseState):
|
|
|
+ """Set the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to set the state for.
|
|
|
+ state: The state to set.
|
|
|
+ """
|
|
|
+ pass
|
|
|
+
|
|
|
+ @override
|
|
|
+ @contextlib.asynccontextmanager
|
|
|
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
+ """Modify the state for a token while holding exclusive lock.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to modify the state for.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
+ token = _split_substate_key(token)[0]
|
|
|
+ if token not in self._states_locks:
|
|
|
+ async with self._state_manager_lock:
|
|
|
+ if token not in self._states_locks:
|
|
|
+ self._states_locks[token] = asyncio.Lock()
|
|
|
+
|
|
|
+ async with self._states_locks[token]:
|
|
|
+ state = await self.get_state(token)
|
|
|
+ yield state
|
|
|
+ await self.set_state(token, state)
|
|
|
+
|
|
|
+
|
|
|
+def _default_token_expiration() -> int:
|
|
|
+ """Get the default token expiration time.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The default token expiration time.
|
|
|
+ """
|
|
|
+ return get_config().redis_token_expiration
|
|
|
+
|
|
|
+
|
|
|
+def reset_disk_state_manager():
|
|
|
+ """Reset the disk state manager."""
|
|
|
+ states_directory = prerequisites.get_states_dir()
|
|
|
+ if states_directory.exists():
|
|
|
+ for path in states_directory.iterdir():
|
|
|
+ path.unlink()
|
|
|
+
|
|
|
+
|
|
|
+@dataclasses.dataclass
|
|
|
+class StateManagerDisk(StateManager):
|
|
|
+ """A state manager that stores states in memory."""
|
|
|
+
|
|
|
+ # The mapping of client ids to states.
|
|
|
+ states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
|
|
|
+
|
|
|
+ # The mutex ensures the dict of mutexes is updated exclusively
|
|
|
+ _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
|
|
|
+
|
|
|
+ # The dict of mutexes for each client
|
|
|
+ _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
|
|
|
+ default_factory=dict,
|
|
|
+ init=False,
|
|
|
+ )
|
|
|
+
|
|
|
+ # The token expiration time (s).
|
|
|
+ token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
|
|
|
+
|
|
|
+ def __post_init_(self):
|
|
|
+ """Create a new state manager."""
|
|
|
+ path_ops.mkdir(self.states_directory)
|
|
|
+
|
|
|
+ self._purge_expired_states()
|
|
|
+
|
|
|
+ @functools.cached_property
|
|
|
+ def states_directory(self) -> Path:
|
|
|
+ """Get the states directory.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The states directory.
|
|
|
+ """
|
|
|
+ return prerequisites.get_states_dir()
|
|
|
+
|
|
|
+ def _purge_expired_states(self):
|
|
|
+ """Purge expired states from the disk."""
|
|
|
+ import time
|
|
|
+
|
|
|
+ for path in path_ops.ls(self.states_directory):
|
|
|
+ # check path is a pickle file
|
|
|
+ if path.suffix != ".pkl":
|
|
|
+ continue
|
|
|
+
|
|
|
+ # load last edited field from file
|
|
|
+ last_edited = path.stat().st_mtime
|
|
|
+
|
|
|
+ # check if the file is older than the token expiration time
|
|
|
+ if time.time() - last_edited > self.token_expiration:
|
|
|
+ # remove the file
|
|
|
+ path.unlink()
|
|
|
+
|
|
|
+ def token_path(self, token: str) -> Path:
|
|
|
+ """Get the path for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the path for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The path for the token.
|
|
|
+ """
|
|
|
+ return (
|
|
|
+ self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
|
|
|
+ ).absolute()
|
|
|
+
|
|
|
+ async def load_state(self, token: str) -> BaseState | None:
|
|
|
+ """Load a state object based on the provided token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token used to identify the state object.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The loaded state object or None.
|
|
|
+ """
|
|
|
+ token_path = self.token_path(token)
|
|
|
+
|
|
|
+ if token_path.exists():
|
|
|
+ try:
|
|
|
+ with token_path.open(mode="rb") as file:
|
|
|
+ return BaseState._deserialize(fp=file)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ async def populate_substates(
|
|
|
+ self, client_token: str, state: BaseState, root_state: BaseState
|
|
|
+ ):
|
|
|
+ """Populate the substates of a state object.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ client_token: The client token.
|
|
|
+ state: The state object to populate.
|
|
|
+ root_state: The root state object.
|
|
|
+ """
|
|
|
+ for substate in state.get_substates():
|
|
|
+ substate_token = _substate_key(client_token, substate)
|
|
|
+
|
|
|
+ fresh_instance = await root_state.get_state(substate)
|
|
|
+ instance = await self.load_state(substate_token)
|
|
|
+ if instance is not None:
|
|
|
+ # Ensure all substates exist, even if they weren't serialized previously.
|
|
|
+ instance.substates = fresh_instance.substates
|
|
|
+ else:
|
|
|
+ instance = fresh_instance
|
|
|
+ state.substates[substate.get_name()] = instance
|
|
|
+ instance.parent_state = state
|
|
|
+
|
|
|
+ await self.populate_substates(client_token, instance, root_state)
|
|
|
+
|
|
|
+ @override
|
|
|
+ async def get_state(
|
|
|
+ self,
|
|
|
+ token: str,
|
|
|
+ ) -> BaseState:
|
|
|
+ """Get the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the state for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ client_token = _split_substate_key(token)[0]
|
|
|
+ root_state = self.states.get(client_token)
|
|
|
+ if root_state is not None:
|
|
|
+ # Retrieved state from memory.
|
|
|
+ return root_state
|
|
|
+
|
|
|
+ # Deserialize root state from disk.
|
|
|
+ root_state = await self.load_state(_substate_key(client_token, self.state))
|
|
|
+ # Create a new root state tree with all substates instantiated.
|
|
|
+ fresh_root_state = self.state(_reflex_internal_init=True)
|
|
|
+ if root_state is None:
|
|
|
+ root_state = fresh_root_state
|
|
|
+ else:
|
|
|
+ # Ensure all substates exist, even if they were not serialized previously.
|
|
|
+ root_state.substates = fresh_root_state.substates
|
|
|
+ self.states[client_token] = root_state
|
|
|
+ await self.populate_substates(client_token, root_state, root_state)
|
|
|
+ return root_state
|
|
|
+
|
|
|
+ async def set_state_for_substate(self, client_token: str, substate: BaseState):
|
|
|
+ """Set the state for a substate.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ client_token: The client token.
|
|
|
+ substate: The substate to set.
|
|
|
+ """
|
|
|
+ substate_token = _substate_key(client_token, substate)
|
|
|
+
|
|
|
+ if substate._get_was_touched():
|
|
|
+ substate._was_touched = False # Reset the touched flag after serializing.
|
|
|
+ pickle_state = substate._serialize()
|
|
|
+ if pickle_state:
|
|
|
+ if not self.states_directory.exists():
|
|
|
+ self.states_directory.mkdir(parents=True, exist_ok=True)
|
|
|
+ self.token_path(substate_token).write_bytes(pickle_state)
|
|
|
+
|
|
|
+ for substate_substate in substate.substates.values():
|
|
|
+ await self.set_state_for_substate(client_token, substate_substate)
|
|
|
+
|
|
|
+ @override
|
|
|
+ async def set_state(self, token: str, state: BaseState):
|
|
|
+ """Set the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to set the state for.
|
|
|
+ state: The state to set.
|
|
|
+ """
|
|
|
+ client_token, substate = _split_substate_key(token)
|
|
|
+ await self.set_state_for_substate(client_token, state)
|
|
|
+
|
|
|
+ @override
|
|
|
+ @contextlib.asynccontextmanager
|
|
|
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
+ """Modify the state for a token while holding exclusive lock.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to modify the state for.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
+ client_token, substate = _split_substate_key(token)
|
|
|
+ if client_token not in self._states_locks:
|
|
|
+ async with self._state_manager_lock:
|
|
|
+ if client_token not in self._states_locks:
|
|
|
+ self._states_locks[client_token] = asyncio.Lock()
|
|
|
+
|
|
|
+ async with self._states_locks[client_token]:
|
|
|
+ state = await self.get_state(token)
|
|
|
+ yield state
|
|
|
+ await self.set_state(token, state)
|
|
|
+
|
|
|
+
|
|
|
+def _default_lock_expiration() -> int:
|
|
|
+ """Get the default lock expiration time.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The default lock expiration time.
|
|
|
+ """
|
|
|
+ return get_config().redis_lock_expiration
|
|
|
+
|
|
|
+
|
|
|
+def _default_lock_warning_threshold() -> int:
|
|
|
+ """Get the default lock warning threshold.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The default lock warning threshold.
|
|
|
+ """
|
|
|
+ return get_config().redis_lock_warning_threshold
|
|
|
+
|
|
|
+
|
|
|
+@dataclasses.dataclass
|
|
|
+class StateManagerRedis(StateManager):
|
|
|
+ """A state manager that stores states in redis."""
|
|
|
+
|
|
|
+ # The redis client to use.
|
|
|
+ redis: Redis
|
|
|
+
|
|
|
+ # The token expiration time (s).
|
|
|
+ token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
|
|
|
+
|
|
|
+ # The maximum time to hold a lock (ms).
|
|
|
+ lock_expiration: int = dataclasses.field(default_factory=_default_lock_expiration)
|
|
|
+
|
|
|
+ # The maximum time to hold a lock (ms) before warning.
|
|
|
+ lock_warning_threshold: int = dataclasses.field(
|
|
|
+ default_factory=_default_lock_warning_threshold
|
|
|
+ )
|
|
|
+
|
|
|
+ # The keyspace subscription string when redis is waiting for lock to be released.
|
|
|
+ _redis_notify_keyspace_events: str = dataclasses.field(
|
|
|
+ default="K" # Enable keyspace notifications (target a particular key)
|
|
|
+ "g" # For generic commands (DEL, EXPIRE, etc)
|
|
|
+ "x" # For expired events
|
|
|
+ "e" # For evicted events (i.e. maxmemory exceeded)
|
|
|
+ )
|
|
|
+
|
|
|
+ # These events indicate that a lock is no longer held.
|
|
|
+ _redis_keyspace_lock_release_events: set[bytes] = dataclasses.field(
|
|
|
+ default_factory=lambda: {
|
|
|
+ b"del",
|
|
|
+ b"expire",
|
|
|
+ b"expired",
|
|
|
+ b"evicted",
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ # Whether keyspace notifications have been enabled.
|
|
|
+ _redis_notify_keyspace_events_enabled: bool = dataclasses.field(default=False)
|
|
|
+
|
|
|
+ # The logical database number used by the redis client.
|
|
|
+ _redis_db: int = dataclasses.field(default=0)
|
|
|
+
|
|
|
+ def __post_init__(self):
|
|
|
+ """Validate the lock warning threshold.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
|
|
|
+ """
|
|
|
+ if self.lock_warning_threshold >= (lock_expiration := self.lock_expiration):
|
|
|
+ raise InvalidLockWarningThresholdError(
|
|
|
+ f"The lock warning threshold({self.lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
|
|
|
+ )
|
|
|
+
|
|
|
+ def _get_required_state_classes(
|
|
|
+ self,
|
|
|
+ target_state_cls: type[BaseState],
|
|
|
+ subclasses: bool = False,
|
|
|
+ required_state_classes: set[type[BaseState]] | None = None,
|
|
|
+ ) -> set[type[BaseState]]:
|
|
|
+ """Recursively determine which states are required to fetch the target state.
|
|
|
+
|
|
|
+ This will always include potentially dirty substates that depend on vars
|
|
|
+ in the target_state_cls.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ target_state_cls: The target state class being fetched.
|
|
|
+ subclasses: Whether to include subclasses of the target state.
|
|
|
+ required_state_classes: Recursive argument tracking state classes that have already been seen.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The set of state classes required to fetch the target state.
|
|
|
+ """
|
|
|
+ if required_state_classes is None:
|
|
|
+ required_state_classes = set()
|
|
|
+ # Get the substates if requested.
|
|
|
+ if subclasses:
|
|
|
+ for substate in target_state_cls.get_substates():
|
|
|
+ self._get_required_state_classes(
|
|
|
+ substate,
|
|
|
+ subclasses=True,
|
|
|
+ required_state_classes=required_state_classes,
|
|
|
+ )
|
|
|
+ if target_state_cls in required_state_classes:
|
|
|
+ return required_state_classes
|
|
|
+ required_state_classes.add(target_state_cls)
|
|
|
+
|
|
|
+ # Get dependent substates.
|
|
|
+ for pd_substates in target_state_cls._get_potentially_dirty_states():
|
|
|
+ self._get_required_state_classes(
|
|
|
+ pd_substates,
|
|
|
+ subclasses=False,
|
|
|
+ required_state_classes=required_state_classes,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Get the parent state if it exists.
|
|
|
+ if parent_state := target_state_cls.get_parent_state():
|
|
|
+ self._get_required_state_classes(
|
|
|
+ parent_state,
|
|
|
+ subclasses=False,
|
|
|
+ required_state_classes=required_state_classes,
|
|
|
+ )
|
|
|
+ return required_state_classes
|
|
|
+
|
|
|
+ def _get_populated_states(
|
|
|
+ self,
|
|
|
+ target_state: BaseState,
|
|
|
+ populated_states: dict[str, BaseState] | None = None,
|
|
|
+ ) -> dict[str, BaseState]:
|
|
|
+ """Recursively determine which states from target_state are already fetched.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ target_state: The state to check for populated states.
|
|
|
+ populated_states: Recursive argument tracking states seen in previous calls.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ A dictionary of state full name to state instance.
|
|
|
+ """
|
|
|
+ if populated_states is None:
|
|
|
+ populated_states = {}
|
|
|
+ if target_state.get_full_name() in populated_states:
|
|
|
+ return populated_states
|
|
|
+ populated_states[target_state.get_full_name()] = target_state
|
|
|
+ for substate in target_state.substates.values():
|
|
|
+ self._get_populated_states(substate, populated_states=populated_states)
|
|
|
+ if target_state.parent_state is not None:
|
|
|
+ self._get_populated_states(
|
|
|
+ target_state.parent_state, populated_states=populated_states
|
|
|
+ )
|
|
|
+ return populated_states
|
|
|
+
|
|
|
+ @override
|
|
|
+ async def get_state(
|
|
|
+ self,
|
|
|
+ token: str,
|
|
|
+ top_level: bool = True,
|
|
|
+ for_state_instance: BaseState | None = None,
|
|
|
+ ) -> BaseState:
|
|
|
+ """Get the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the state for.
|
|
|
+ top_level: If true, return an instance of the top-level state (self.state).
|
|
|
+ for_state_instance: If provided, attach the requested states to this existing state tree.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state for the token.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
|
|
|
+ requested state was not fetched.
|
|
|
+ """
|
|
|
+ # Split the actual token from the fully qualified substate name.
|
|
|
+ token, state_path = _split_substate_key(token)
|
|
|
+ if state_path:
|
|
|
+ # Get the State class associated with the given path.
|
|
|
+ state_cls = self.state.get_class_substate(state_path)
|
|
|
+ else:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Determine which states we already have.
|
|
|
+ flat_state_tree: dict[str, BaseState] = (
|
|
|
+ self._get_populated_states(for_state_instance) if for_state_instance else {}
|
|
|
+ )
|
|
|
+
|
|
|
+ # Determine which states from the tree need to be fetched.
|
|
|
+ required_state_classes = sorted(
|
|
|
+ self._get_required_state_classes(state_cls, subclasses=True)
|
|
|
+ - {type(s) for s in flat_state_tree.values()},
|
|
|
+ key=lambda x: x.get_full_name(),
|
|
|
+ )
|
|
|
+
|
|
|
+ redis_pipeline = self.redis.pipeline()
|
|
|
+ for state_cls in required_state_classes:
|
|
|
+ redis_pipeline.get(_substate_key(token, state_cls))
|
|
|
+
|
|
|
+ for state_cls, redis_state in zip(
|
|
|
+ required_state_classes,
|
|
|
+ await redis_pipeline.execute(),
|
|
|
+ strict=False,
|
|
|
+ ):
|
|
|
+ state = None
|
|
|
+
|
|
|
+ if redis_state is not None:
|
|
|
+ # Deserialize the substate.
|
|
|
+ with contextlib.suppress(StateSchemaMismatchError):
|
|
|
+ state = BaseState._deserialize(data=redis_state)
|
|
|
+ if state is None:
|
|
|
+ # Key didn't exist or schema mismatch so create a new instance for this token.
|
|
|
+ state = state_cls(
|
|
|
+ init_substates=False,
|
|
|
+ _reflex_internal_init=True,
|
|
|
+ )
|
|
|
+ flat_state_tree[state.get_full_name()] = state
|
|
|
+ if state.get_parent_state() is not None:
|
|
|
+ parent_state_name, _dot, state_name = state.get_full_name().rpartition(
|
|
|
+ "."
|
|
|
+ )
|
|
|
+ parent_state = flat_state_tree.get(parent_state_name)
|
|
|
+ if parent_state is None:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Parent state for {state.get_full_name()} was not found "
|
|
|
+ "in the state tree, but should have already been fetched. "
|
|
|
+ "This is a bug",
|
|
|
+ )
|
|
|
+ parent_state.substates[state_name] = state
|
|
|
+ state.parent_state = parent_state
|
|
|
+
|
|
|
+ # To retain compatibility with previous implementation, by default, we return
|
|
|
+ # the top-level state which should always be fetched or already cached.
|
|
|
+ if top_level:
|
|
|
+ return flat_state_tree[self.state.get_full_name()]
|
|
|
+ return flat_state_tree[state_cls.get_full_name()]
|
|
|
+
|
|
|
+ @override
|
|
|
+ async def set_state(
|
|
|
+ self,
|
|
|
+ token: str,
|
|
|
+ state: BaseState,
|
|
|
+ lock_id: bytes | None = None,
|
|
|
+ ):
|
|
|
+ """Set the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to set the state for.
|
|
|
+ state: The state to set.
|
|
|
+ lock_id: If provided, the lock_key must be set to this value to set the state.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
|
|
+ RuntimeError: If the state instance doesn't match the state name in the token.
|
|
|
+ """
|
|
|
+ # Check that we're holding the lock.
|
|
|
+ if (
|
|
|
+ lock_id is not None
|
|
|
+ and await self.redis.get(self._lock_key(token)) != lock_id
|
|
|
+ ):
|
|
|
+ raise LockExpiredError(
|
|
|
+ f"Lock expired for token {token} while processing. Consider increasing "
|
|
|
+ f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
|
|
+ "or use `@rx.event(background=True)` decorator for long-running tasks."
|
|
|
+ )
|
|
|
+ elif lock_id is not None:
|
|
|
+ time_taken = self.lock_expiration / 1000 - (
|
|
|
+ await self.redis.ttl(self._lock_key(token))
|
|
|
+ )
|
|
|
+ if time_taken > self.lock_warning_threshold / 1000:
|
|
|
+ console.warn(
|
|
|
+ f"Lock for token {token} was held too long {time_taken=}s, "
|
|
|
+ f"use `@rx.event(background=True)` decorator for long-running tasks.",
|
|
|
+ dedupe=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ client_token, substate_name = _split_substate_key(token)
|
|
|
+ # If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
|
|
+ if state.parent_state is not None and state.get_full_name() != substate_name:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
|
|
|
+ )
|
|
|
+
|
|
|
+ # Recursively set_state on all known substates.
|
|
|
+ tasks = [
|
|
|
+ asyncio.create_task(
|
|
|
+ self.set_state(
|
|
|
+ _substate_key(client_token, substate),
|
|
|
+ substate,
|
|
|
+ lock_id,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ for substate in state.substates.values()
|
|
|
+ ]
|
|
|
+ # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
|
|
|
+ if state._get_was_touched():
|
|
|
+ pickle_state = state._serialize()
|
|
|
+ if pickle_state:
|
|
|
+ await self.redis.set(
|
|
|
+ _substate_key(client_token, state),
|
|
|
+ pickle_state,
|
|
|
+ ex=self.token_expiration,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Wait for substates to be persisted.
|
|
|
+ for t in tasks:
|
|
|
+ await t
|
|
|
+
|
|
|
+ @override
|
|
|
+ @contextlib.asynccontextmanager
|
|
|
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
+ """Modify the state for a token while holding exclusive lock.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to modify the state for.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ async with self._lock(token) as lock_id:
|
|
|
+ state = await self.get_state(token)
|
|
|
+ yield state
|
|
|
+ await self.set_state(token, state, lock_id)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _lock_key(token: str) -> bytes:
|
|
|
+ """Get the redis key for a token's lock.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the lock key for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The redis lock key for the token.
|
|
|
+ """
|
|
|
+ # All substates share the same lock domain, so ignore any substate path suffix.
|
|
|
+ client_token = _split_substate_key(token)[0]
|
|
|
+ return f"{client_token}_lock".encode()
|
|
|
+
|
|
|
+ async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
|
|
+ """Try to get a redis lock for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ lock_key: The redis key for the lock.
|
|
|
+ lock_id: The ID of the lock.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ True if the lock was obtained.
|
|
|
+ """
|
|
|
+ return await self.redis.set(
|
|
|
+ lock_key,
|
|
|
+ lock_id,
|
|
|
+ px=self.lock_expiration,
|
|
|
+ nx=True, # only set if it doesn't exist
|
|
|
+ )
|
|
|
+
|
|
|
+ async def _get_pubsub_message(
|
|
|
+ self, pubsub: PubSub, timeout: float | None = None
|
|
|
+ ) -> None:
|
|
|
+ """Get lock release events from the pubsub.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ pubsub: The pubsub to get a message from.
|
|
|
+ timeout: Remaining time to wait for a message.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The message.
|
|
|
+ """
|
|
|
+ if timeout is None:
|
|
|
+ timeout = self.lock_expiration / 1000.0
|
|
|
+
|
|
|
+ started = time.time()
|
|
|
+ message = await pubsub.get_message(
|
|
|
+ ignore_subscribe_messages=True,
|
|
|
+ timeout=timeout,
|
|
|
+ )
|
|
|
+ if (
|
|
|
+ message is None
|
|
|
+ or message["data"] not in self._redis_keyspace_lock_release_events
|
|
|
+ ):
|
|
|
+ remaining = timeout - (time.time() - started)
|
|
|
+ if remaining <= 0:
|
|
|
+ return
|
|
|
+ await self._get_pubsub_message(pubsub, timeout=remaining)
|
|
|
+
|
|
|
+ async def _enable_keyspace_notifications(self):
|
|
|
+ """Enable keyspace notifications for the redis server.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ResponseError: when the keyspace config cannot be set.
|
|
|
+ """
|
|
|
+ if self._redis_notify_keyspace_events_enabled:
|
|
|
+ return
|
|
|
+ # Find out which logical database index is being used.
|
|
|
+ self._redis_db = self.redis.get_connection_kwargs().get("db", self._redis_db)
|
|
|
+
|
|
|
+ try:
|
|
|
+ await self.redis.config_set(
|
|
|
+ "notify-keyspace-events",
|
|
|
+ self._redis_notify_keyspace_events,
|
|
|
+ )
|
|
|
+ except ResponseError:
|
|
|
+ # Some redis servers only allow out-of-band configuration, so ignore errors here.
|
|
|
+ if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
|
|
|
+ raise
|
|
|
+ self._redis_notify_keyspace_events_enabled = True
|
|
|
+
|
|
|
+ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
|
|
|
+ """Wait for a redis lock to be released via pubsub.
|
|
|
+
|
|
|
+ Coroutine will not return until the lock is obtained.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ lock_key: The redis key for the lock.
|
|
|
+ lock_id: The ID of the lock.
|
|
|
+ """
|
|
|
+ # Enable keyspace notifications for the lock key, so we know when it is available.
|
|
|
+ await self._enable_keyspace_notifications()
|
|
|
+ lock_key_channel = f"__keyspace@{self._redis_db}__:{lock_key.decode()}"
|
|
|
+ async with self.redis.pubsub() as pubsub:
|
|
|
+ await pubsub.psubscribe(lock_key_channel)
|
|
|
+ # wait for the lock to be released
|
|
|
+ while True:
|
|
|
+ # fast path
|
|
|
+ if await self._try_get_lock(lock_key, lock_id):
|
|
|
+ return
|
|
|
+ # wait for lock events
|
|
|
+ await self._get_pubsub_message(pubsub)
|
|
|
+
|
|
|
+ @contextlib.asynccontextmanager
|
|
|
+ async def _lock(self, token: str):
|
|
|
+ """Obtain a redis lock for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to obtain a lock for.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The ID of the lock (to be passed to set_state).
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ LockExpiredError: If the lock has expired while processing the event.
|
|
|
+ """
|
|
|
+ lock_key = self._lock_key(token)
|
|
|
+ lock_id = uuid.uuid4().hex.encode()
|
|
|
+
|
|
|
+ if not await self._try_get_lock(lock_key, lock_id):
|
|
|
+ # Missed the fast-path to get lock, subscribe for lock delete/expire events
|
|
|
+ await self._wait_lock(lock_key, lock_id)
|
|
|
+ state_is_locked = True
|
|
|
+
|
|
|
+ try:
|
|
|
+ yield lock_id
|
|
|
+ except LockExpiredError:
|
|
|
+ state_is_locked = False
|
|
|
+ raise
|
|
|
+ finally:
|
|
|
+ if state_is_locked:
|
|
|
+ # only delete our lock
|
|
|
+ await self.redis.delete(lock_key)
|
|
|
+
|
|
|
+ async def close(self):
|
|
|
+ """Explicitly close the redis connection and connection_pool.
|
|
|
+
|
|
|
+ It is necessary in testing scenarios to close between asyncio test cases
|
|
|
+ to avoid having lingering redis connections associated with event loops
|
|
|
+ that will be closed (each test case uses its own event loop).
|
|
|
+
|
|
|
+ Note: Connections will be automatically reopened when needed.
|
|
|
+ """
|
|
|
+ await self.redis.aclose(close_connection_pool=True)
|
|
|
+
|
|
|
+
|
|
|
+def get_state_manager() -> StateManager:
|
|
|
+ """Get the state manager for the app that is currently running.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state manager.
|
|
|
+ """
|
|
|
+ return prerequisites.get_and_validate_app().app.state_manager
|