|
@@ -11,6 +11,7 @@ import os
|
|
|
import uuid
|
|
|
from abc import ABC, abstractmethod
|
|
|
from collections import defaultdict
|
|
|
+from pathlib import Path
|
|
|
from types import FunctionType, MethodType
|
|
|
from typing import (
|
|
|
TYPE_CHECKING,
|
|
@@ -23,6 +24,7 @@ from typing import (
|
|
|
Optional,
|
|
|
Sequence,
|
|
|
Set,
|
|
|
+ Tuple,
|
|
|
Type,
|
|
|
Union,
|
|
|
cast,
|
|
@@ -52,7 +54,7 @@ from reflex.event import (
|
|
|
EventSpec,
|
|
|
fix_events,
|
|
|
)
|
|
|
-from reflex.utils import console, format, prerequisites, types
|
|
|
+from reflex.utils import console, format, path_ops, prerequisites, types
|
|
|
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
|
|
from reflex.utils.exec import is_testing_env
|
|
|
from reflex.utils.serializers import SerializedType, serialize, serializer
|
|
@@ -2339,7 +2341,7 @@ class StateManager(Base, ABC):
|
|
|
token_expiration=config.redis_token_expiration,
|
|
|
lock_expiration=config.redis_lock_expiration,
|
|
|
)
|
|
|
- return StateManagerMemory(state=state)
|
|
|
+ return StateManagerDisk(state=state)
|
|
|
|
|
|
@abstractmethod
|
|
|
async def get_state(self, token: str) -> BaseState:
|
|
@@ -2446,6 +2448,244 @@ class StateManagerMemory(StateManager):
|
|
|
await self.set_state(token, state)
|
|
|
|
|
|
|
|
|
+def _default_token_expiration() -> int:
|
|
|
+ """Get the default token expiration time.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The default token expiration time.
|
|
|
+ """
|
|
|
+ return get_config().redis_token_expiration
|
|
|
+
|
|
|
+
|
|
|
+def state_to_schema(
|
|
|
+ state: BaseState,
|
|
|
+) -> List[
|
|
|
+ Tuple[
|
|
|
+ str,
|
|
|
+ str,
|
|
|
+ Any,
|
|
|
+ Union[bool, None],
|
|
|
+ ]
|
|
|
+]:
|
|
|
+ """Convert a state to a schema.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state: The state to convert to a schema.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The schema.
|
|
|
+ """
|
|
|
+ return list(
|
|
|
+ sorted(
|
|
|
+ (
|
|
|
+ field_name,
|
|
|
+ model_field.name,
|
|
|
+ model_field.type_,
|
|
|
+ (
|
|
|
+ model_field.required
|
|
|
+ if isinstance(model_field.required, bool)
|
|
|
+ else None
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ for field_name, model_field in state.__fields__.items()
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+class StateManagerDisk(StateManager):
|
|
|
+ """A state manager that stores states in memory."""
|
|
|
+
|
|
|
+ # The mapping of client ids to states.
|
|
|
+ states: Dict[str, BaseState] = {}
|
|
|
+
|
|
|
+ # The mutex ensures the dict of mutexes is updated exclusively
|
|
|
+ _state_manager_lock = asyncio.Lock()
|
|
|
+
|
|
|
+ # The dict of mutexes for each client
|
|
|
+ _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
|
|
|
+
|
|
|
+ # The token expiration time (s).
|
|
|
+ token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
|
|
|
+
|
|
|
+ class Config:
|
|
|
+ """The Pydantic config."""
|
|
|
+
|
|
|
+ fields = {
|
|
|
+ "_states_locks": {"exclude": True},
|
|
|
+ }
|
|
|
+ keep_untouched = (functools.cached_property,)
|
|
|
+
|
|
|
+ def __init__(self, state: Type[BaseState]):
|
|
|
+ """Create a new state manager.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state: The state class to use.
|
|
|
+ """
|
|
|
+ super().__init__(state=state)
|
|
|
+
|
|
|
+ path_ops.mkdir(self.states_directory)
|
|
|
+
|
|
|
+ self._purge_expired_states()
|
|
|
+
|
|
|
+ @functools.cached_property
|
|
|
+ def states_directory(self) -> Path:
|
|
|
+ """Get the states directory.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The states directory.
|
|
|
+ """
|
|
|
+ return prerequisites.get_web_dir() / constants.Dirs.STATES
|
|
|
+
|
|
|
+ def _purge_expired_states(self):
|
|
|
+ """Purge expired states from the disk."""
|
|
|
+ import time
|
|
|
+
|
|
|
+ for path in path_ops.ls(self.states_directory):
|
|
|
+ # check path is a pickle file
|
|
|
+ if path.suffix != ".pkl":
|
|
|
+ continue
|
|
|
+
|
|
|
+ # load last edited field from file
|
|
|
+ last_edited = path.stat().st_mtime
|
|
|
+
|
|
|
+ # check if the file is older than the token expiration time
|
|
|
+ if time.time() - last_edited > self.token_expiration:
|
|
|
+ # remove the file
|
|
|
+ path.unlink()
|
|
|
+
|
|
|
+ def token_path(self, token: str) -> Path:
|
|
|
+ """Get the path for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the path for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The path for the token.
|
|
|
+ """
|
|
|
+ return (self.states_directory / f"{token}.pkl").absolute()
|
|
|
+
|
|
|
+ async def load_state(self, token: str, root_state: BaseState) -> BaseState:
|
|
|
+ """Load a state object based on the provided token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token used to identify the state object.
|
|
|
+ root_state: The root state object.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The loaded state object.
|
|
|
+ """
|
|
|
+ if token in self.states:
|
|
|
+ return self.states[token]
|
|
|
+
|
|
|
+ client_token, substate_address = _split_substate_key(token)
|
|
|
+
|
|
|
+ token_path = self.token_path(token)
|
|
|
+
|
|
|
+ if token_path.exists():
|
|
|
+ try:
|
|
|
+ with token_path.open(mode="rb") as file:
|
|
|
+ (substate_schema, substate) = dill.load(file)
|
|
|
+ if substate_schema == state_to_schema(substate):
|
|
|
+ await self.populate_substates(client_token, substate, root_state)
|
|
|
+ return substate
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ return root_state.get_substate(substate_address.split(".")[1:])
|
|
|
+
|
|
|
+ async def populate_substates(
|
|
|
+ self, client_token: str, state: BaseState, root_state: BaseState
|
|
|
+ ):
|
|
|
+ """Populate the substates of a state object.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ client_token: The client token.
|
|
|
+ state: The state object to populate.
|
|
|
+ root_state: The root state object.
|
|
|
+ """
|
|
|
+ for substate in state.get_substates():
|
|
|
+ substate_token = _substate_key(client_token, substate)
|
|
|
+
|
|
|
+ substate = await self.load_state(substate_token, root_state)
|
|
|
+
|
|
|
+ state.substates[substate.get_name()] = substate
|
|
|
+ substate.parent_state = state
|
|
|
+
|
|
|
+ @override
|
|
|
+ async def get_state(
|
|
|
+ self,
|
|
|
+ token: str,
|
|
|
+ ) -> BaseState:
|
|
|
+ """Get the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to get the state for.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ client_token, substate_address = _split_substate_key(token)
|
|
|
+
|
|
|
+ root_state_token = _substate_key(client_token, substate_address.split(".")[0])
|
|
|
+
|
|
|
+ return await self.load_state(
|
|
|
+ root_state_token, self.state(_reflex_internal_init=True)
|
|
|
+ )
|
|
|
+
|
|
|
+ async def set_state_for_substate(self, client_token: str, substate: BaseState):
|
|
|
+ """Set the state for a substate.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ client_token: The client token.
|
|
|
+ substate: The substate to set.
|
|
|
+ """
|
|
|
+ substate_token = _substate_key(client_token, substate)
|
|
|
+
|
|
|
+ self.states[substate_token] = substate
|
|
|
+
|
|
|
+ state_dilled = dill.dumps((state_to_schema(substate), substate), byref=True)
|
|
|
+ if not self.states_directory.exists():
|
|
|
+ self.states_directory.mkdir(parents=True, exist_ok=True)
|
|
|
+ self.token_path(substate_token).write_bytes(state_dilled)
|
|
|
+
|
|
|
+ for substate_substate in substate.substates.values():
|
|
|
+ await self.set_state_for_substate(client_token, substate_substate)
|
|
|
+
|
|
|
+ @override
|
|
|
+ async def set_state(self, token: str, state: BaseState):
|
|
|
+ """Set the state for a token.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to set the state for.
|
|
|
+ state: The state to set.
|
|
|
+ """
|
|
|
+ client_token, substate = _split_substate_key(token)
|
|
|
+ await self.set_state_for_substate(client_token, state)
|
|
|
+
|
|
|
+ @override
|
|
|
+ @contextlib.asynccontextmanager
|
|
|
+ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
+ """Modify the state for a token while holding exclusive lock.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to modify the state for.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The state for the token.
|
|
|
+ """
|
|
|
+ # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
+ client_token, substate = _split_substate_key(token)
|
|
|
+ if client_token not in self._states_locks:
|
|
|
+ async with self._state_manager_lock:
|
|
|
+ if client_token not in self._states_locks:
|
|
|
+ self._states_locks[client_token] = asyncio.Lock()
|
|
|
+
|
|
|
+ async with self._states_locks[client_token]:
|
|
|
+ state = await self.get_state(token)
|
|
|
+ yield state
|
|
|
+ await self.set_state(token, state)
|
|
|
+
|
|
|
+
|
|
|
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
|
|
|
if not isinstance(State.validate.__func__, FunctionType):
|
|
|
cython_function_or_method = type(State.validate.__func__)
|
|
@@ -2474,15 +2714,6 @@ def _default_lock_expiration() -> int:
|
|
|
return get_config().redis_lock_expiration
|
|
|
|
|
|
|
|
|
-def _default_token_expiration() -> int:
|
|
|
- """Get the default token expiration time.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The default token expiration time.
|
|
|
- """
|
|
|
- return get_config().redis_token_expiration
|
|
|
-
|
|
|
-
|
|
|
class StateManagerRedis(StateManager):
|
|
|
"""A state manager that stores states in redis."""
|
|
|
|