- """State manager for managing client states."""
- import asyncio
- import contextlib
- import dataclasses
- import functools
- import time
- import uuid
- from abc import ABC, abstractmethod
- from collections.abc import AsyncIterator
- from hashlib import md5
- from pathlib import Path
- from redis import ResponseError
- from redis.asyncio import Redis
- from redis.asyncio.client import PubSub
- from typing_extensions import override
- from reflex import constants
- from reflex.config import environment, get_config
- from reflex.state import BaseState, _split_substate_key, _substate_key
- from reflex.utils import console, path_ops, prerequisites
- from reflex.utils.exceptions import (
- InvalidLockWarningThresholdError,
- InvalidStateManagerModeError,
- LockExpiredError,
- StateSchemaMismatchError,
- )
- @dataclasses.dataclass
- class StateManager(ABC):
- """A class to manage many client states."""
- # The state class to use.
- state: type[BaseState]
- @classmethod
- def create(cls, state: type[BaseState]):
- """Create a new state manager.
- Args:
- state: The state class to use.
- Raises:
- InvalidStateManagerModeError: If the state manager mode is invalid.
- Returns:
- The state manager (either disk, memory or redis).
- """
- config = get_config()
- if prerequisites.parse_redis_url() is not None:
- config.state_manager_mode = constants.StateManagerMode.REDIS
- if config.state_manager_mode == constants.StateManagerMode.MEMORY:
- return StateManagerMemory(state=state)
- if config.state_manager_mode == constants.StateManagerMode.DISK:
- return StateManagerDisk(state=state)
- if config.state_manager_mode == constants.StateManagerMode.REDIS:
- redis = prerequisites.get_redis()
- if redis is not None:
- # make sure expiration values are obtained only from the config object on creation
- return StateManagerRedis(
- state=state,
- redis=redis,
- token_expiration=config.redis_token_expiration,
- lock_expiration=config.redis_lock_expiration,
- lock_warning_threshold=config.redis_lock_warning_threshold,
- )
- raise InvalidStateManagerModeError(
- f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
- )
- @abstractmethod
- async def get_state(self, token: str) -> BaseState:
- """Get the state for a token.
- Args:
- token: The token to get the state for.
- Returns:
- The state for the token.
- """
- pass
- @abstractmethod
- async def set_state(self, token: str, state: BaseState):
- """Set the state for a token.
- Args:
- token: The token to set the state for.
- state: The state to set.
- """
- pass
- @abstractmethod
- @contextlib.asynccontextmanager
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
- """Modify the state for a token while holding exclusive lock.
- Args:
- token: The token to modify the state for.
- Yields:
- The state for the token.
- """
- yield self.state()
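- # Usage sketch (hypothetical helper; ``token`` stands for whatever identifier the
- # client sends): every concrete manager exposes the same three operations, and
- # ``modify_state`` is the one event handlers typically rely on, since it holds the
- # per-client lock for the duration of the mutation and persists the state on exit.
- async def _example_modify(manager: StateManager, token: str) -> None:
-     async with manager.modify_state(token) as state:
-         ...  # mutate fields defined on the app's state class here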
- @dataclasses.dataclass
- class StateManagerMemory(StateManager):
- """A state manager that stores states in memory."""
- # The mapping of client ids to states.
- states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
- # The mutex ensures the dict of mutexes is updated exclusively
- _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
- # The dict of mutexes for each client
- _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
- default_factory=dict, init=False
- )
- @override
- async def get_state(self, token: str) -> BaseState:
- """Get the state for a token.
- Args:
- token: The token to get the state for.
- Returns:
- The state for the token.
- """
- # Memory state manager ignores the substate suffix and always returns the top-level state.
- token = _split_substate_key(token)[0]
- if token not in self.states:
- self.states[token] = self.state(_reflex_internal_init=True)
- return self.states[token]
- @override
- async def set_state(self, token: str, state: BaseState):
- """Set the state for a token.
- Args:
- token: The token to set the state for.
- state: The state to set.
- """
- pass
- @override
- @contextlib.asynccontextmanager
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
- """Modify the state for a token while holding exclusive lock.
- Args:
- token: The token to modify the state for.
- Yields:
- The state for the token.
- """
- # Memory state manager ignores the substate suffix and always returns the top-level state.
- token = _split_substate_key(token)[0]
- if token not in self._states_locks:
- async with self._state_manager_lock:
- if token not in self._states_locks:
- self._states_locks[token] = asyncio.Lock()
- async with self._states_locks[token]:
- state = await self.get_state(token)
- yield state
- await self.set_state(token, state)
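- # Sketch of the memory manager's token handling (the uuid token is illustrative):
- # because the substate suffix is ignored, a full substate key and the bare client
- # token both resolve to the same in-memory root state instance.
- async def _example_memory_token_scoping(manager: StateManagerMemory) -> None:
-     client_token = uuid.uuid4().hex
-     root_a = await manager.get_state(_substate_key(client_token, manager.state))
-     root_b = await manager.get_state(client_token)
-     assert root_a is root_b  # same top-level state either way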
- def _default_token_expiration() -> int:
- """Get the default token expiration time.
- Returns:
- The default token expiration time.
- """
- return get_config().redis_token_expiration
- def reset_disk_state_manager():
- """Reset the disk state manager."""
- states_directory = prerequisites.get_states_dir()
- if states_directory.exists():
- for path in states_directory.iterdir():
- path.unlink()
- @dataclasses.dataclass
- class StateManagerDisk(StateManager):
- """A state manager that stores states in memory."""
- # The mapping of client ids to states.
- states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
- # The mutex ensures the dict of mutexes is updated exclusively
- _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
- # The dict of mutexes for each client
- _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
- default_factory=dict,
- init=False,
- )
- # The token expiration time (s).
- token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
- def __post_init__(self):
- """Create a new state manager."""
- path_ops.mkdir(self.states_directory)
- self._purge_expired_states()
- @functools.cached_property
- def states_directory(self) -> Path:
- """Get the states directory.
- Returns:
- The states directory.
- """
- return prerequisites.get_states_dir()
- def _purge_expired_states(self):
- """Purge expired states from the disk."""
- for path in path_ops.ls(self.states_directory):
- # check path is a pickle file
- if path.suffix != ".pkl":
- continue
- # load last edited field from file
- last_edited = path.stat().st_mtime
- # check if the file is older than the token expiration time
- if time.time() - last_edited > self.token_expiration:
- # remove the file
- path.unlink()
- def token_path(self, token: str) -> Path:
- """Get the path for a token.
- Args:
- token: The token to get the path for.
- Returns:
- The path for the token.
- """
- return (
- self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
- ).absolute()
- async def load_state(self, token: str) -> BaseState | None:
- """Load a state object based on the provided token.
- Args:
- token: The token used to identify the state object.
- Returns:
- The loaded state object or None.
- """
- token_path = self.token_path(token)
- if token_path.exists():
- try:
- with token_path.open(mode="rb") as file:
- return BaseState._deserialize(fp=file)
- except Exception:
- pass
- async def populate_substates(
- self, client_token: str, state: BaseState, root_state: BaseState
- ):
- """Populate the substates of a state object.
- Args:
- client_token: The client token.
- state: The state object to populate.
- root_state: The root state object.
- """
- for substate in state.get_substates():
- substate_token = _substate_key(client_token, substate)
- fresh_instance = await root_state.get_state(substate)
- instance = await self.load_state(substate_token)
- if instance is not None:
- # Ensure all substates exist, even if they weren't serialized previously.
- instance.substates = fresh_instance.substates
- else:
- instance = fresh_instance
- state.substates[substate.get_name()] = instance
- instance.parent_state = state
- await self.populate_substates(client_token, instance, root_state)
- @override
- async def get_state(
- self,
- token: str,
- ) -> BaseState:
- """Get the state for a token.
- Args:
- token: The token to get the state for.
- Returns:
- The state for the token.
- """
- client_token = _split_substate_key(token)[0]
- root_state = self.states.get(client_token)
- if root_state is not None:
- # Retrieved state from memory.
- return root_state
- # Deserialize root state from disk.
- root_state = await self.load_state(_substate_key(client_token, self.state))
- # Create a new root state tree with all substates instantiated.
- fresh_root_state = self.state(_reflex_internal_init=True)
- if root_state is None:
- root_state = fresh_root_state
- else:
- # Ensure all substates exist, even if they were not serialized previously.
- root_state.substates = fresh_root_state.substates
- self.states[client_token] = root_state
- await self.populate_substates(client_token, root_state, root_state)
- return root_state
- async def set_state_for_substate(self, client_token: str, substate: BaseState):
- """Set the state for a substate.
- Args:
- client_token: The client token.
- substate: The substate to set.
- """
- substate_token = _substate_key(client_token, substate)
- if substate._get_was_touched():
- substate._was_touched = False # Reset the touched flag after serializing.
- pickle_state = substate._serialize()
- if pickle_state:
- if not self.states_directory.exists():
- self.states_directory.mkdir(parents=True, exist_ok=True)
- self.token_path(substate_token).write_bytes(pickle_state)
- for substate_substate in substate.substates.values():
- await self.set_state_for_substate(client_token, substate_substate)
- @override
- async def set_state(self, token: str, state: BaseState):
- """Set the state for a token.
- Args:
- token: The token to set the state for.
- state: The state to set.
- """
- client_token, substate = _split_substate_key(token)
- await self.set_state_for_substate(client_token, state)
- @override
- @contextlib.asynccontextmanager
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
- """Modify the state for a token while holding exclusive lock.
- Args:
- token: The token to modify the state for.
- Yields:
- The state for the token.
- """
- # The disk state manager locks per client token and always yields the top-level state, regardless of the substate suffix.
- client_token, substate = _split_substate_key(token)
- if client_token not in self._states_locks:
- async with self._state_manager_lock:
- if client_token not in self._states_locks:
- self._states_locks[client_token] = asyncio.Lock()
- async with self._states_locks[client_token]:
- state = await self.get_state(token)
- yield state
- await self.set_state(token, state)
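- # Sketch (hypothetical helper): where a client's pickled root state ends up on
- # disk. ``token_path`` md5-hashes the full substate token, so the same token always
- # maps to the same ``<md5>.pkl`` file, and files older than ``token_expiration``
- # seconds are purged when the manager is created.
- def _example_disk_state_file(manager: StateManagerDisk, client_token: str) -> Path:
-     return manager.token_path(_substate_key(client_token, manager.state))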
- def _default_lock_expiration() -> int:
- """Get the default lock expiration time.
- Returns:
- The default lock expiration time.
- """
- return get_config().redis_lock_expiration
- def _default_lock_warning_threshold() -> int:
- """Get the default lock warning threshold.
- Returns:
- The default lock warning threshold.
- """
- return get_config().redis_lock_warning_threshold
- @dataclasses.dataclass
- class StateManagerRedis(StateManager):
- """A state manager that stores states in redis."""
- # The redis client to use.
- redis: Redis
- # The token expiration time (s).
- token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
- # The maximum time to hold a lock (ms).
- lock_expiration: int = dataclasses.field(default_factory=_default_lock_expiration)
- # The maximum time to hold a lock (ms) before warning.
- lock_warning_threshold: int = dataclasses.field(
- default_factory=_default_lock_warning_threshold
- )
- # The keyspace notification configuration applied while waiting for a lock to be released.
- _redis_notify_keyspace_events: str = dataclasses.field(
- default="K" # Enable keyspace notifications (target a particular key)
- "g" # For generic commands (DEL, EXPIRE, etc)
- "x" # For expired events
- "e" # For evicted events (i.e. maxmemory exceeded)
- )
- # These events indicate that a lock is no longer held.
- _redis_keyspace_lock_release_events: set[bytes] = dataclasses.field(
- default_factory=lambda: {
- b"del",
- b"expire",
- b"expired",
- b"evicted",
- }
- )
- # Whether keyspace notifications have been enabled.
- _redis_notify_keyspace_events_enabled: bool = dataclasses.field(default=False)
- # The logical database number used by the redis client.
- _redis_db: int = dataclasses.field(default=0)
- def __post_init__(self):
- """Validate the lock warning threshold.
- Raises:
- InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
- """
- if self.lock_warning_threshold >= (lock_expiration := self.lock_expiration):
- raise InvalidLockWarningThresholdError(
- f"The lock warning threshold({self.lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
- )
- def _get_required_state_classes(
- self,
- target_state_cls: type[BaseState],
- subclasses: bool = False,
- required_state_classes: set[type[BaseState]] | None = None,
- ) -> set[type[BaseState]]:
- """Recursively determine which states are required to fetch the target state.
- This will always include potentially dirty substates that depend on vars
- in the target_state_cls.
- Args:
- target_state_cls: The target state class being fetched.
- subclasses: Whether to include subclasses of the target state.
- required_state_classes: Recursive argument tracking state classes that have already been seen.
- Returns:
- The set of state classes required to fetch the target state.
- """
- if required_state_classes is None:
- required_state_classes = set()
- # Get the substates if requested.
- if subclasses:
- for substate in target_state_cls.get_substates():
- self._get_required_state_classes(
- substate,
- subclasses=True,
- required_state_classes=required_state_classes,
- )
- if target_state_cls in required_state_classes:
- return required_state_classes
- required_state_classes.add(target_state_cls)
- # Get dependent substates.
- for pd_substates in target_state_cls._get_potentially_dirty_states():
- self._get_required_state_classes(
- pd_substates,
- subclasses=False,
- required_state_classes=required_state_classes,
- )
- # Get the parent state if it exists.
- if parent_state := target_state_cls.get_parent_state():
- self._get_required_state_classes(
- parent_state,
- subclasses=False,
- required_state_classes=required_state_classes,
- )
- return required_state_classes
- def _get_populated_states(
- self,
- target_state: BaseState,
- populated_states: dict[str, BaseState] | None = None,
- ) -> dict[str, BaseState]:
- """Recursively determine which states from target_state are already fetched.
- Args:
- target_state: The state to check for populated states.
- populated_states: Recursive argument tracking states seen in previous calls.
- Returns:
- A dictionary of state full name to state instance.
- """
- if populated_states is None:
- populated_states = {}
- if target_state.get_full_name() in populated_states:
- return populated_states
- populated_states[target_state.get_full_name()] = target_state
- for substate in target_state.substates.values():
- self._get_populated_states(substate, populated_states=populated_states)
- if target_state.parent_state is not None:
- self._get_populated_states(
- target_state.parent_state, populated_states=populated_states
- )
- return populated_states
- @override
- async def get_state(
- self,
- token: str,
- top_level: bool = True,
- for_state_instance: BaseState | None = None,
- ) -> BaseState:
- """Get the state for a token.
- Args:
- token: The token to get the state for.
- top_level: If true, return an instance of the top-level state (self.state).
- for_state_instance: If provided, attach the requested states to this existing state tree.
- Returns:
- The state for the token.
- Raises:
- RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
- requested state was not fetched.
- """
- # Split the actual token from the fully qualified substate name.
- token, state_path = _split_substate_key(token)
- if state_path:
- # Get the State class associated with the given path.
- state_cls = self.state.get_class_substate(state_path)
- else:
- raise RuntimeError(
- f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
- )
- # Determine which states we already have.
- flat_state_tree: dict[str, BaseState] = (
- self._get_populated_states(for_state_instance) if for_state_instance else {}
- )
- # Determine which states from the tree need to be fetched.
- required_state_classes = sorted(
- self._get_required_state_classes(state_cls, subclasses=True)
- - {type(s) for s in flat_state_tree.values()},
- key=lambda x: x.get_full_name(),
- )
- redis_pipeline = self.redis.pipeline()
- for state_cls in required_state_classes:
- redis_pipeline.get(_substate_key(token, state_cls))
- for state_cls, redis_state in zip(
- required_state_classes,
- await redis_pipeline.execute(),
- strict=False,
- ):
- state = None
- if redis_state is not None:
- # Deserialize the substate.
- with contextlib.suppress(StateSchemaMismatchError):
- state = BaseState._deserialize(data=redis_state)
- if state is None:
- # Key didn't exist or schema mismatch so create a new instance for this token.
- state = state_cls(
- init_substates=False,
- _reflex_internal_init=True,
- )
- flat_state_tree[state.get_full_name()] = state
- if state.get_parent_state() is not None:
- parent_state_name, _dot, state_name = state.get_full_name().rpartition(
- "."
- )
- parent_state = flat_state_tree.get(parent_state_name)
- if parent_state is None:
- raise RuntimeError(
- f"Parent state for {state.get_full_name()} was not found "
- "in the state tree, but should have already been fetched. "
- "This is a bug",
- )
- parent_state.substates[state_name] = state
- state.parent_state = parent_state
- # To retain compatibility with previous implementation, by default, we return
- # the top-level state which should always be fetched or already cached.
- if top_level:
- return flat_state_tree[self.state.get_full_name()]
- return flat_state_tree[state_cls.get_full_name()]
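- # Callers address a specific substate by combining the client token with the
- # state's full dotted name, e.g. (illustrative):
- #     state = await manager.get_state(_substate_key(client_token, SomeSubstate))
- # Passing a bare client token with no state name raises the RuntimeError above.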
- @override
- async def set_state(
- self,
- token: str,
- state: BaseState,
- lock_id: bytes | None = None,
- ):
- """Set the state for a token.
- Args:
- token: The token to set the state for.
- state: The state to set.
- lock_id: If provided, the lock_key must be set to this value to set the state.
- Raises:
- LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
- RuntimeError: If the state instance doesn't match the state name in the token.
- """
- # Check that we're holding the lock.
- if (
- lock_id is not None
- and await self.redis.get(self._lock_key(token)) != lock_id
- ):
- raise LockExpiredError(
- f"Lock expired for token {token} while processing. Consider increasing "
- f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
- "or use `@rx.event(background=True)` decorator for long-running tasks."
- )
- elif lock_id is not None:
- time_taken = self.lock_expiration / 1000 - (
- await self.redis.ttl(self._lock_key(token))
- )
- if time_taken > self.lock_warning_threshold / 1000:
- console.warn(
- f"Lock for token {token} was held too long {time_taken=}s, "
- f"use `@rx.event(background=True)` decorator for long-running tasks.",
- dedupe=True,
- )
- client_token, substate_name = _split_substate_key(token)
- # If the substate name on the token doesn't match the instance name, it cannot have a parent.
- if state.parent_state is not None and state.get_full_name() != substate_name:
- raise RuntimeError(
- f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
- )
- # Recursively set_state on all known substates.
- tasks = [
- asyncio.create_task(
- self.set_state(
- _substate_key(client_token, substate),
- substate,
- lock_id,
- )
- )
- for substate in state.substates.values()
- ]
- # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
- if state._get_was_touched():
- pickle_state = state._serialize()
- if pickle_state:
- await self.redis.set(
- _substate_key(client_token, state),
- pickle_state,
- ex=self.token_expiration,
- )
- # Wait for substates to be persisted.
- for t in tasks:
- await t
- @override
- @contextlib.asynccontextmanager
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
- """Modify the state for a token while holding exclusive lock.
- Args:
- token: The token to modify the state for.
- Yields:
- The state for the token.
- """
- async with self._lock(token) as lock_id:
- state = await self.get_state(token)
- yield state
- await self.set_state(token, state, lock_id)
- @staticmethod
- def _lock_key(token: str) -> bytes:
- """Get the redis key for a token's lock.
- Args:
- token: The token to get the lock key for.
- Returns:
- The redis lock key for the token.
- """
- # All substates share the same lock domain, so ignore any substate path suffix.
- client_token = _split_substate_key(token)[0]
- return f"{client_token}_lock".encode()
- async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
- """Try to get a redis lock for a token.
- Args:
- lock_key: The redis key for the lock.
- lock_id: The ID of the lock.
- Returns:
- True if the lock was obtained.
- """
- return await self.redis.set(
- lock_key,
- lock_id,
- px=self.lock_expiration,
- nx=True, # only set if it doesn't exist
- )
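- # Lock acquisition is a single redis SET with NX (only set if absent) and PX
- # (auto-expire after lock_expiration ms); roughly equivalent to the redis-cli call
- #     SET <client_token>_lock <lock_id> PX <lock_expiration> NX
- # A falsy result means another client currently holds the lock.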
- async def _get_pubsub_message(
- self, pubsub: PubSub, timeout: float | None = None
- ) -> None:
- """Get lock release events from the pubsub.
- Args:
- pubsub: The pubsub to get a message from.
- timeout: Remaining time to wait for a message.
- Returns:
- The message.
- """
- if timeout is None:
- timeout = self.lock_expiration / 1000.0
- started = time.time()
- message = await pubsub.get_message(
- ignore_subscribe_messages=True,
- timeout=timeout,
- )
- if (
- message is None
- or message["data"] not in self._redis_keyspace_lock_release_events
- ):
- remaining = timeout - (time.time() - started)
- if remaining <= 0:
- return
- await self._get_pubsub_message(pubsub, timeout=remaining)
- async def _enable_keyspace_notifications(self):
- """Enable keyspace notifications for the redis server.
- Raises:
- ResponseError: when the keyspace config cannot be set.
- """
- if self._redis_notify_keyspace_events_enabled:
- return
- # Find out which logical database index is being used.
- self._redis_db = self.redis.get_connection_kwargs().get("db", self._redis_db)
- try:
- await self.redis.config_set(
- "notify-keyspace-events",
- self._redis_notify_keyspace_events,
- )
- except ResponseError:
- # Some redis servers only allow out-of-band configuration, so ignore errors here.
- if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
- raise
- self._redis_notify_keyspace_events_enabled = True
- async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
- """Wait for a redis lock to be released via pubsub.
- Coroutine will not return until the lock is obtained.
- Args:
- lock_key: The redis key for the lock.
- lock_id: The ID of the lock.
- """
- # Enable keyspace notifications for the lock key, so we know when it is available.
- await self._enable_keyspace_notifications()
- lock_key_channel = f"__keyspace@{self._redis_db}__:{lock_key.decode()}"
- async with self.redis.pubsub() as pubsub:
- await pubsub.psubscribe(lock_key_channel)
- # wait for the lock to be released
- while True:
- # fast path
- if await self._try_get_lock(lock_key, lock_id):
- return
- # wait for lock events
- await self._get_pubsub_message(pubsub)
- @contextlib.asynccontextmanager
- async def _lock(self, token: str):
- """Obtain a redis lock for a token.
- Args:
- token: The token to obtain a lock for.
- Yields:
- The ID of the lock (to be passed to set_state).
- Raises:
- LockExpiredError: If the lock has expired while processing the event.
- """
- lock_key = self._lock_key(token)
- lock_id = uuid.uuid4().hex.encode()
- if not await self._try_get_lock(lock_key, lock_id):
- # Missed the fast-path to get lock, subscribe for lock delete/expire events
- await self._wait_lock(lock_key, lock_id)
- state_is_locked = True
- try:
- yield lock_id
- except LockExpiredError:
- state_is_locked = False
- raise
- finally:
- if state_is_locked:
- # only delete our lock
- await self.redis.delete(lock_key)
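- # Locking flow in summary: try the SET NX fast path first; on contention,
- # subscribe to keyspace events for the lock key and retry whenever a
- # del/expire/expired/evicted event fires; on exit, delete the lock unless it
- # already expired (in which case LockExpiredError has been raised).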
- async def close(self):
- """Explicitly close the redis connection and connection_pool.
- It is necessary in testing scenarios to close between asyncio test cases
- to avoid having lingering redis connections associated with event loops
- that will be closed (each test case uses its own event loop).
- Note: Connections will be automatically reopened when needed.
- """
- await self.redis.aclose(close_connection_pool=True)
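- # Construction sketch (assumes a redis server reachable at the given address; all
- # values are illustrative): expirations come from the config in normal operation,
- # and __post_init__ rejects a warning threshold >= the lock expiration.
- def _example_redis_manager(state_cls: type[BaseState]) -> StateManagerRedis:
-     return StateManagerRedis(
-         state=state_cls,
-         redis=Redis(host="localhost", port=6379, db=0),
-         token_expiration=3600,        # seconds a client's state is retained
-         lock_expiration=10000,        # ms a lock may be held
-         lock_warning_threshold=8000,  # ms before a long-held lock logs a warning
-     )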
- def get_state_manager() -> StateManager:
- """Get the state manager for the app that is currently running.
- Returns:
- The state manager.
- """
- return prerequisites.get_and_validate_app().app.state_manager
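- # Read-only sketch (the token value is illustrative): fetch the current manager
- # from the running app and load a client's state. For StateManagerRedis the token
- # must carry the ``{client_token}_{state_full_name}`` suffix expected by get_state.
- async def _example_read_state(token: str) -> BaseState:
-     manager = get_state_manager()
-     return await manager.get_state(token)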
|