|
@@ -9,24 +9,19 @@ import copy
|
|
|
import dataclasses
|
|
|
import functools
|
|
|
import inspect
|
|
|
-import json
|
|
|
import pickle
|
|
|
import sys
|
|
|
-import time
|
|
|
import typing
|
|
|
-import uuid
|
|
|
import warnings
|
|
|
-from abc import ABC, abstractmethod
|
|
|
+from abc import ABC
|
|
|
from collections.abc import AsyncIterator, Callable, Sequence
|
|
|
from hashlib import md5
|
|
|
-from pathlib import Path
|
|
|
-from types import FunctionType, MethodType
|
|
|
+from types import FunctionType
|
|
|
from typing import (
|
|
|
TYPE_CHECKING,
|
|
|
Any,
|
|
|
BinaryIO,
|
|
|
ClassVar,
|
|
|
- SupportsIndex,
|
|
|
TypeVar,
|
|
|
cast,
|
|
|
get_args,
|
|
@@ -34,22 +29,16 @@ from typing import (
|
|
|
)
|
|
|
|
|
|
import pydantic.v1 as pydantic
|
|
|
-import wrapt
|
|
|
from pydantic import BaseModel as BaseModelV2
|
|
|
from pydantic.v1 import BaseModel as BaseModelV1
|
|
|
-from pydantic.v1 import validator
|
|
|
from pydantic.v1.fields import ModelField
|
|
|
-from redis.asyncio import Redis
|
|
|
-from redis.asyncio.client import PubSub
|
|
|
-from redis.exceptions import ResponseError
|
|
|
from rich.markup import escape
|
|
|
-from sqlalchemy.orm import DeclarativeBase
|
|
|
from typing_extensions import Self
|
|
|
|
|
|
import reflex.istate.dynamic
|
|
|
from reflex import constants, event
|
|
|
from reflex.base import Base
|
|
|
-from reflex.config import PerformanceMode, environment, get_config
|
|
|
+from reflex.config import PerformanceMode, environment
|
|
|
from reflex.event import (
|
|
|
BACKGROUND_TASK_MARKER,
|
|
|
Event,
|
|
@@ -58,19 +47,17 @@ from reflex.event import (
|
|
|
fix_events,
|
|
|
)
|
|
|
from reflex.istate.data import RouterData
|
|
|
+from reflex.istate.proxy import ImmutableMutableProxy as ImmutableMutableProxy
|
|
|
+from reflex.istate.proxy import MutableProxy, StateProxy
|
|
|
from reflex.istate.storage import ClientStorageBase
|
|
|
from reflex.model import Model
|
|
|
-from reflex.utils import console, format, path_ops, prerequisites, types
|
|
|
+from reflex.utils import console, format, prerequisites, types
|
|
|
from reflex.utils.exceptions import (
|
|
|
ComputedVarShadowsBaseVarsError,
|
|
|
ComputedVarShadowsStateVarError,
|
|
|
DynamicComponentInvalidSignatureError,
|
|
|
DynamicRouteArgShadowsStateVarError,
|
|
|
EventHandlerShadowsBuiltInStateMethodError,
|
|
|
- ImmutableStateError,
|
|
|
- InvalidLockWarningThresholdError,
|
|
|
- InvalidStateManagerModeError,
|
|
|
- LockExpiredError,
|
|
|
ReflexRuntimeError,
|
|
|
SetUndefinedStateVarError,
|
|
|
StateMismatchError,
|
|
@@ -79,13 +66,12 @@ from reflex.utils.exceptions import (
|
|
|
StateTooLargeError,
|
|
|
UnretrievableVarValueError,
|
|
|
)
|
|
|
+from reflex.utils.exceptions import ImmutableStateError as ImmutableStateError
|
|
|
from reflex.utils.exec import is_testing_env
|
|
|
-from reflex.utils.serializers import serializer
|
|
|
from reflex.utils.types import (
|
|
|
_isinstance,
|
|
|
get_origin,
|
|
|
is_union,
|
|
|
- override,
|
|
|
true_type_for_pydantic_field,
|
|
|
value_inside_optional,
|
|
|
)
|
|
@@ -2284,6 +2270,35 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|
|
return state
|
|
|
|
|
|
|
|
|
+def _serialize_type(type_: Any) -> str:
|
|
|
+ """Serialize a type.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ type_: The type to serialize.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The serialized type.
|
|
|
+ """
|
|
|
+ if not inspect.isclass(type_):
|
|
|
+ return f"{type_}"
|
|
|
+ return f"{type_.__module__}.{type_.__qualname__}"
|
|
|
+
|
|
|
+
|
|
|
+def is_serializable(value: Any) -> bool:
|
|
|
+ """Check if a value is serializable.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ value: The value to check.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Whether the value is serializable.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ return bool(pickle.dumps(value))
|
|
|
+ except Exception:
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
T_STATE = TypeVar("T_STATE", bound=BaseState)
|
|
|
|
|
|
|
|
@@ -2523,278 +2538,6 @@ class ComponentState(State, mixin=True):
|
|
|
return component
|
|
|
|
|
|
|
|
|
-class StateProxy(wrapt.ObjectProxy):
|
|
|
- """Proxy of a state instance to control mutability of vars for a background task.
|
|
|
-
|
|
|
- Since a background task runs against a state instance without holding the
|
|
|
- state_manager lock for the token, the reference may become stale if the same
|
|
|
- state is modified by another event handler.
|
|
|
-
|
|
|
- The proxy object ensures that writes to the state are blocked unless
|
|
|
- explicitly entering a context which refreshes the state from state_manager
|
|
|
- and holds the lock for the token until exiting the context. After exiting
|
|
|
- the context, a StateUpdate may be emitted to the frontend to notify the
|
|
|
- client of the state change.
|
|
|
-
|
|
|
- A background task will be passed the `StateProxy` as `self`, so mutability
|
|
|
- can be safely performed inside an `async with self` block.
|
|
|
-
|
|
|
- class State(rx.State):
|
|
|
- counter: int = 0
|
|
|
-
|
|
|
- @rx.event(background=True)
|
|
|
- async def bg_increment(self):
|
|
|
- await asyncio.sleep(1)
|
|
|
- async with self:
|
|
|
- self.counter += 1
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- state_instance: BaseState,
|
|
|
- parent_state_proxy: StateProxy | None = None,
|
|
|
- ):
|
|
|
- """Create a proxy for a state instance.
|
|
|
-
|
|
|
- If `get_state` is used on a StateProxy, the resulting state will be
|
|
|
- linked to the given state via parent_state_proxy. The first state in the
|
|
|
- chain is the state that initiated the background task.
|
|
|
-
|
|
|
- Args:
|
|
|
- state_instance: The state instance to proxy.
|
|
|
- parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
|
|
|
- """
|
|
|
- super().__init__(state_instance)
|
|
|
- # compile is not relevant to backend logic
|
|
|
- self._self_app = prerequisites.get_and_validate_app().app
|
|
|
- self._self_substate_path = tuple(state_instance.get_full_name().split("."))
|
|
|
- self._self_actx = None
|
|
|
- self._self_mutable = False
|
|
|
- self._self_actx_lock = asyncio.Lock()
|
|
|
- self._self_actx_lock_holder = None
|
|
|
- self._self_parent_state_proxy = parent_state_proxy
|
|
|
-
|
|
|
- def _is_mutable(self) -> bool:
|
|
|
- """Check if the state is mutable.
|
|
|
-
|
|
|
- Returns:
|
|
|
- Whether the state is mutable.
|
|
|
- """
|
|
|
- if self._self_parent_state_proxy is not None:
|
|
|
- return self._self_parent_state_proxy._is_mutable() or self._self_mutable
|
|
|
- return self._self_mutable
|
|
|
-
|
|
|
- async def __aenter__(self) -> StateProxy:
|
|
|
- """Enter the async context manager protocol.
|
|
|
-
|
|
|
- Sets mutability to True and enters the `App.modify_state` async context,
|
|
|
- which refreshes the state from state_manager and holds the lock for the
|
|
|
- given state token until exiting the context.
|
|
|
-
|
|
|
- Background tasks should avoid blocking calls while inside the context.
|
|
|
-
|
|
|
- Returns:
|
|
|
- This StateProxy instance in mutable mode.
|
|
|
-
|
|
|
- Raises:
|
|
|
- ImmutableStateError: If the state is already mutable.
|
|
|
- """
|
|
|
- if self._self_parent_state_proxy is not None:
|
|
|
- parent_state = (
|
|
|
- await self._self_parent_state_proxy.__aenter__()
|
|
|
- ).__wrapped__
|
|
|
- super().__setattr__(
|
|
|
- "__wrapped__",
|
|
|
- await parent_state.get_state(
|
|
|
- State.get_class_substate(self._self_substate_path)
|
|
|
- ),
|
|
|
- )
|
|
|
- return self
|
|
|
- current_task = asyncio.current_task()
|
|
|
- if (
|
|
|
- self._self_actx_lock.locked()
|
|
|
- and current_task == self._self_actx_lock_holder
|
|
|
- ):
|
|
|
- raise ImmutableStateError(
|
|
|
- "The state is already mutable. Do not nest `async with self` blocks."
|
|
|
- )
|
|
|
- await self._self_actx_lock.acquire()
|
|
|
- self._self_actx_lock_holder = current_task
|
|
|
- self._self_actx = self._self_app.modify_state(
|
|
|
- token=_substate_key(
|
|
|
- self.__wrapped__.router.session.client_token,
|
|
|
- self._self_substate_path,
|
|
|
- )
|
|
|
- )
|
|
|
- mutable_state = await self._self_actx.__aenter__()
|
|
|
- super().__setattr__(
|
|
|
- "__wrapped__", mutable_state.get_substate(self._self_substate_path)
|
|
|
- )
|
|
|
- self._self_mutable = True
|
|
|
- return self
|
|
|
-
|
|
|
- async def __aexit__(self, *exc_info: Any) -> None:
|
|
|
- """Exit the async context manager protocol.
|
|
|
-
|
|
|
- Sets proxy mutability to False and persists any state changes.
|
|
|
-
|
|
|
- Args:
|
|
|
- exc_info: The exception info tuple.
|
|
|
- """
|
|
|
- if self._self_parent_state_proxy is not None:
|
|
|
- await self._self_parent_state_proxy.__aexit__(*exc_info)
|
|
|
- return
|
|
|
- if self._self_actx is None:
|
|
|
- return
|
|
|
- self._self_mutable = False
|
|
|
- try:
|
|
|
- await self._self_actx.__aexit__(*exc_info)
|
|
|
- finally:
|
|
|
- self._self_actx_lock_holder = None
|
|
|
- self._self_actx_lock.release()
|
|
|
- self._self_actx = None
|
|
|
-
|
|
|
- def __enter__(self):
|
|
|
- """Enter the regular context manager protocol.
|
|
|
-
|
|
|
- This is not supported for background tasks, and exists only to raise a more useful exception
|
|
|
- when the StateProxy is used incorrectly.
|
|
|
-
|
|
|
- Raises:
|
|
|
- TypeError: always, because only async contextmanager protocol is supported.
|
|
|
- """
|
|
|
- raise TypeError("Background task must use `async with self` to modify state.")
|
|
|
-
|
|
|
- def __exit__(self, *exc_info: Any) -> None:
|
|
|
- """Exit the regular context manager protocol.
|
|
|
-
|
|
|
- Args:
|
|
|
- exc_info: The exception info tuple.
|
|
|
- """
|
|
|
- pass
|
|
|
-
|
|
|
- def __getattr__(self, name: str) -> Any:
|
|
|
- """Get the attribute from the underlying state instance.
|
|
|
-
|
|
|
- Args:
|
|
|
- name: The name of the attribute.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The value of the attribute.
|
|
|
-
|
|
|
- Raises:
|
|
|
- ImmutableStateError: If the state is not in mutable mode.
|
|
|
- """
|
|
|
- if name in ["substates", "parent_state"] and not self._is_mutable():
|
|
|
- raise ImmutableStateError(
|
|
|
- "Background task StateProxy is immutable outside of a context "
|
|
|
- "manager. Use `async with self` to modify state."
|
|
|
- )
|
|
|
- value = super().__getattr__(name)
|
|
|
- if not name.startswith("_self_") and isinstance(value, MutableProxy):
|
|
|
- # ensure mutations to these containers are blocked unless proxy is _mutable
|
|
|
- return ImmutableMutableProxy(
|
|
|
- wrapped=value.__wrapped__,
|
|
|
- state=self,
|
|
|
- field_name=value._self_field_name,
|
|
|
- )
|
|
|
- if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
|
|
|
- # Rebind event handler to the proxy instance
|
|
|
- value = functools.partial(
|
|
|
- value.func,
|
|
|
- self,
|
|
|
- *value.args[1:],
|
|
|
- **value.keywords,
|
|
|
- )
|
|
|
- if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
|
|
|
- # Rebind methods to the proxy instance
|
|
|
- value = type(value)(value.__func__, self)
|
|
|
- return value
|
|
|
-
|
|
|
- def __setattr__(self, name: str, value: Any) -> None:
|
|
|
- """Set the attribute on the underlying state instance.
|
|
|
-
|
|
|
- If the attribute is internal, set it on the proxy instance instead.
|
|
|
-
|
|
|
- Args:
|
|
|
- name: The name of the attribute.
|
|
|
- value: The value of the attribute.
|
|
|
-
|
|
|
- Raises:
|
|
|
- ImmutableStateError: If the state is not in mutable mode.
|
|
|
- """
|
|
|
- if (
|
|
|
- name.startswith("_self_") # wrapper attribute
|
|
|
- or self._is_mutable() # lock held
|
|
|
- # non-persisted state attribute
|
|
|
- or name in self.__wrapped__.get_skip_vars()
|
|
|
- ):
|
|
|
- super().__setattr__(name, value)
|
|
|
- return
|
|
|
-
|
|
|
- raise ImmutableStateError(
|
|
|
- "Background task StateProxy is immutable outside of a context "
|
|
|
- "manager. Use `async with self` to modify state."
|
|
|
- )
|
|
|
-
|
|
|
- def get_substate(self, path: Sequence[str]) -> BaseState:
|
|
|
- """Only allow substate access with lock held.
|
|
|
-
|
|
|
- Args:
|
|
|
- path: The path to the substate.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The substate.
|
|
|
-
|
|
|
- Raises:
|
|
|
- ImmutableStateError: If the state is not in mutable mode.
|
|
|
- """
|
|
|
- if not self._is_mutable():
|
|
|
- raise ImmutableStateError(
|
|
|
- "Background task StateProxy is immutable outside of a context "
|
|
|
- "manager. Use `async with self` to modify state."
|
|
|
- )
|
|
|
- return self.__wrapped__.get_substate(path)
|
|
|
-
|
|
|
- async def get_state(self, state_cls: type[BaseState]) -> BaseState:
|
|
|
- """Get an instance of the state associated with this token.
|
|
|
-
|
|
|
- Args:
|
|
|
- state_cls: The class of the state.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The state.
|
|
|
-
|
|
|
- Raises:
|
|
|
- ImmutableStateError: If the state is not in mutable mode.
|
|
|
- """
|
|
|
- if not self._is_mutable():
|
|
|
- raise ImmutableStateError(
|
|
|
- "Background task StateProxy is immutable outside of a context "
|
|
|
- "manager. Use `async with self` to modify state."
|
|
|
- )
|
|
|
- return type(self)(
|
|
|
- await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
|
|
|
- )
|
|
|
-
|
|
|
- async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
|
|
- """Temporarily allow mutability to access parent_state.
|
|
|
-
|
|
|
- Args:
|
|
|
- *args: The args to pass to the underlying state instance.
|
|
|
- **kwargs: The kwargs to pass to the underlying state instance.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The state update.
|
|
|
- """
|
|
|
- original_mutable = self._self_mutable
|
|
|
- self._self_mutable = True
|
|
|
- try:
|
|
|
- return await self.__wrapped__._as_state_update(*args, **kwargs)
|
|
|
- finally:
|
|
|
- self._self_mutable = original_mutable
|
|
|
-
|
|
|
-
|
|
|
@dataclasses.dataclass(
|
|
|
frozen=True,
|
|
|
)
|
|
@@ -2819,1347 +2562,54 @@ class StateUpdate:
|
|
|
return format.json_dumps(self)
|
|
|
|
|
|
|
|
|
-class StateManager(Base, ABC):
|
|
|
- """A class to manage many client states."""
|
|
|
-
|
|
|
- # The state class to use.
|
|
|
- state: type[BaseState]
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create(cls, state: type[BaseState]):
|
|
|
- """Create a new state manager.
|
|
|
-
|
|
|
- Args:
|
|
|
- state: The state class to use.
|
|
|
-
|
|
|
- Raises:
|
|
|
- InvalidStateManagerModeError: If the state manager mode is invalid.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The state manager (either disk, memory or redis).
|
|
|
- """
|
|
|
- config = get_config()
|
|
|
- if prerequisites.parse_redis_url() is not None:
|
|
|
- config.state_manager_mode = constants.StateManagerMode.REDIS
|
|
|
- if config.state_manager_mode == constants.StateManagerMode.MEMORY:
|
|
|
- return StateManagerMemory(state=state)
|
|
|
- if config.state_manager_mode == constants.StateManagerMode.DISK:
|
|
|
- return StateManagerDisk(state=state)
|
|
|
- if config.state_manager_mode == constants.StateManagerMode.REDIS:
|
|
|
- redis = prerequisites.get_redis()
|
|
|
- if redis is not None:
|
|
|
- # make sure expiration values are obtained only from the config object on creation
|
|
|
- return StateManagerRedis(
|
|
|
- state=state,
|
|
|
- redis=redis,
|
|
|
- token_expiration=config.redis_token_expiration,
|
|
|
- lock_expiration=config.redis_lock_expiration,
|
|
|
- lock_warning_threshold=config.redis_lock_warning_threshold,
|
|
|
- )
|
|
|
- raise InvalidStateManagerModeError(
|
|
|
- f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
|
|
- )
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- async def get_state(self, token: str) -> BaseState:
|
|
|
- """Get the state for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to get the state for.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The state for the token.
|
|
|
- """
|
|
|
- pass
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- async def set_state(self, token: str, state: BaseState):
|
|
|
- """Set the state for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to set the state for.
|
|
|
- state: The state to set.
|
|
|
- """
|
|
|
- pass
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- @contextlib.asynccontextmanager
|
|
|
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
- """Modify the state for a token while holding exclusive lock.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to modify the state for.
|
|
|
-
|
|
|
- Yields:
|
|
|
- The state for the token.
|
|
|
- """
|
|
|
- yield self.state()
|
|
|
-
|
|
|
-
|
|
|
-class StateManagerMemory(StateManager):
|
|
|
- """A state manager that stores states in memory."""
|
|
|
-
|
|
|
- # The mapping of client ids to states.
|
|
|
- states: dict[str, BaseState] = {}
|
|
|
-
|
|
|
- # The mutex ensures the dict of mutexes is updated exclusively
|
|
|
- _state_manager_lock = asyncio.Lock()
|
|
|
-
|
|
|
- # The dict of mutexes for each client
|
|
|
- _states_locks: dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
|
|
|
-
|
|
|
- class Config: # pyright: ignore [reportIncompatibleVariableOverride]
|
|
|
- """The Pydantic config."""
|
|
|
-
|
|
|
- fields = {
|
|
|
- "_states_locks": {"exclude": True},
|
|
|
- }
|
|
|
-
|
|
|
- @override
|
|
|
- async def get_state(self, token: str) -> BaseState:
|
|
|
- """Get the state for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to get the state for.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The state for the token.
|
|
|
- """
|
|
|
- # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
- token = _split_substate_key(token)[0]
|
|
|
- if token not in self.states:
|
|
|
- self.states[token] = self.state(_reflex_internal_init=True)
|
|
|
- return self.states[token]
|
|
|
-
|
|
|
- @override
|
|
|
- async def set_state(self, token: str, state: BaseState):
|
|
|
- """Set the state for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to set the state for.
|
|
|
- state: The state to set.
|
|
|
- """
|
|
|
- pass
|
|
|
-
|
|
|
- @override
|
|
|
- @contextlib.asynccontextmanager
|
|
|
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
- """Modify the state for a token while holding exclusive lock.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to modify the state for.
|
|
|
-
|
|
|
- Yields:
|
|
|
- The state for the token.
|
|
|
- """
|
|
|
- # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
- token = _split_substate_key(token)[0]
|
|
|
- if token not in self._states_locks:
|
|
|
- async with self._state_manager_lock:
|
|
|
- if token not in self._states_locks:
|
|
|
- self._states_locks[token] = asyncio.Lock()
|
|
|
-
|
|
|
- async with self._states_locks[token]:
|
|
|
- state = await self.get_state(token)
|
|
|
- yield state
|
|
|
- await self.set_state(token, state)
|
|
|
-
|
|
|
-
|
|
|
-def _default_token_expiration() -> int:
|
|
|
- """Get the default token expiration time.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The default token expiration time.
|
|
|
- """
|
|
|
- return get_config().redis_token_expiration
|
|
|
-
|
|
|
-
|
|
|
-def _serialize_type(type_: Any) -> str:
|
|
|
- """Serialize a type.
|
|
|
+def code_uses_state_contexts(javascript_code: str) -> bool:
|
|
|
+ """Check if the rendered Javascript uses state contexts.
|
|
|
|
|
|
Args:
|
|
|
- type_: The type to serialize.
|
|
|
+ javascript_code: The Javascript code to check.
|
|
|
|
|
|
Returns:
|
|
|
- The serialized type.
|
|
|
+ True if the code attempts to access a member of StateContexts.
|
|
|
"""
|
|
|
- if not inspect.isclass(type_):
|
|
|
- return f"{type_}"
|
|
|
- return f"{type_.__module__}.{type_.__qualname__}"
|
|
|
+ return bool("useContext(StateContexts" in javascript_code)
|
|
|
|
|
|
|
|
|
-def is_serializable(value: Any) -> bool:
|
|
|
- """Check if a value is serializable.
|
|
|
+def reload_state_module(
|
|
|
+ module: str,
|
|
|
+ state: type[BaseState] = State,
|
|
|
+) -> None:
|
|
|
+ """Reset rx.State subclasses to avoid conflict when reloading.
|
|
|
|
|
|
Args:
|
|
|
- value: The value to check.
|
|
|
+ module: The module to reload.
|
|
|
+ state: Recursive argument for the state class to reload.
|
|
|
|
|
|
- Returns:
|
|
|
- Whether the value is serializable.
|
|
|
"""
|
|
|
- try:
|
|
|
- return bool(pickle.dumps(value))
|
|
|
- except Exception:
|
|
|
- return False
|
|
|
-
|
|
|
-
|
|
|
-def reset_disk_state_manager():
|
|
|
- """Reset the disk state manager."""
|
|
|
- states_directory = prerequisites.get_states_dir()
|
|
|
- if states_directory.exists():
|
|
|
- for path in states_directory.iterdir():
|
|
|
- path.unlink()
|
|
|
+ # Clean out all potentially dirty states of reloaded modules.
|
|
|
+ for pd_state in tuple(state._potentially_dirty_states):
|
|
|
+ with contextlib.suppress(ValueError):
|
|
|
+ if (
|
|
|
+ state.get_root_state().get_class_substate(pd_state).__module__ == module
|
|
|
+ and module is not None
|
|
|
+ ):
|
|
|
+ state._potentially_dirty_states.remove(pd_state)
|
|
|
+ for subclass in tuple(state.class_subclasses):
|
|
|
+ reload_state_module(module=module, state=subclass)
|
|
|
+ if subclass.__module__ == module and module is not None:
|
|
|
+ all_base_state_classes.pop(subclass.get_full_name(), None)
|
|
|
+ state.class_subclasses.remove(subclass)
|
|
|
+ state._always_dirty_substates.discard(subclass.get_name())
|
|
|
+ state._var_dependencies = {}
|
|
|
+ state._init_var_dependency_dicts()
|
|
|
+ state.get_class_substate.cache_clear()
|
|
|
|
|
|
|
|
|
-class StateManagerDisk(StateManager):
|
|
|
- """A state manager that stores states in memory."""
|
|
|
-
|
|
|
- # The mapping of client ids to states.
|
|
|
- states: dict[str, BaseState] = {}
|
|
|
-
|
|
|
- # The mutex ensures the dict of mutexes is updated exclusively
|
|
|
- _state_manager_lock = asyncio.Lock()
|
|
|
-
|
|
|
- # The dict of mutexes for each client
|
|
|
- _states_locks: dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
|
|
|
-
|
|
|
- # The token expiration time (s).
|
|
|
- token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
|
|
|
-
|
|
|
- class Config: # pyright: ignore [reportIncompatibleVariableOverride]
|
|
|
- """The Pydantic config."""
|
|
|
-
|
|
|
- fields = {
|
|
|
- "_states_locks": {"exclude": True},
|
|
|
- }
|
|
|
- keep_untouched = (functools.cached_property,)
|
|
|
-
|
|
|
- def __init__(self, state: type[BaseState]):
|
|
|
- """Create a new state manager.
|
|
|
-
|
|
|
- Args:
|
|
|
- state: The state class to use.
|
|
|
- """
|
|
|
- super().__init__(state=state)
|
|
|
-
|
|
|
- path_ops.mkdir(self.states_directory)
|
|
|
-
|
|
|
- self._purge_expired_states()
|
|
|
-
|
|
|
- @functools.cached_property
|
|
|
- def states_directory(self) -> Path:
|
|
|
- """Get the states directory.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The states directory.
|
|
|
- """
|
|
|
- return prerequisites.get_states_dir()
|
|
|
-
|
|
|
- def _purge_expired_states(self):
|
|
|
- """Purge expired states from the disk."""
|
|
|
- import time
|
|
|
-
|
|
|
- for path in path_ops.ls(self.states_directory):
|
|
|
- # check path is a pickle file
|
|
|
- if path.suffix != ".pkl":
|
|
|
- continue
|
|
|
-
|
|
|
- # load last edited field from file
|
|
|
- last_edited = path.stat().st_mtime
|
|
|
-
|
|
|
- # check if the file is older than the token expiration time
|
|
|
- if time.time() - last_edited > self.token_expiration:
|
|
|
- # remove the file
|
|
|
- path.unlink()
|
|
|
-
|
|
|
- def token_path(self, token: str) -> Path:
|
|
|
- """Get the path for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to get the path for.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The path for the token.
|
|
|
- """
|
|
|
- return (
|
|
|
- self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
|
|
|
- ).absolute()
|
|
|
-
|
|
|
- async def load_state(self, token: str) -> BaseState | None:
|
|
|
- """Load a state object based on the provided token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token used to identify the state object.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The loaded state object or None.
|
|
|
- """
|
|
|
- token_path = self.token_path(token)
|
|
|
-
|
|
|
- if token_path.exists():
|
|
|
- try:
|
|
|
- with token_path.open(mode="rb") as file:
|
|
|
- return BaseState._deserialize(fp=file)
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
-
|
|
|
- async def populate_substates(
|
|
|
- self, client_token: str, state: BaseState, root_state: BaseState
|
|
|
- ):
|
|
|
- """Populate the substates of a state object.
|
|
|
-
|
|
|
- Args:
|
|
|
- client_token: The client token.
|
|
|
- state: The state object to populate.
|
|
|
- root_state: The root state object.
|
|
|
- """
|
|
|
- for substate in state.get_substates():
|
|
|
- substate_token = _substate_key(client_token, substate)
|
|
|
-
|
|
|
- fresh_instance = await root_state.get_state(substate)
|
|
|
- instance = await self.load_state(substate_token)
|
|
|
- if instance is not None:
|
|
|
- # Ensure all substates exist, even if they weren't serialized previously.
|
|
|
- instance.substates = fresh_instance.substates
|
|
|
- else:
|
|
|
- instance = fresh_instance
|
|
|
- state.substates[substate.get_name()] = instance
|
|
|
- instance.parent_state = state
|
|
|
-
|
|
|
- await self.populate_substates(client_token, instance, root_state)
|
|
|
-
|
|
|
- @override
|
|
|
- async def get_state(
|
|
|
- self,
|
|
|
- token: str,
|
|
|
- ) -> BaseState:
|
|
|
- """Get the state for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to get the state for.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The state for the token.
|
|
|
- """
|
|
|
- client_token = _split_substate_key(token)[0]
|
|
|
- root_state = self.states.get(client_token)
|
|
|
- if root_state is not None:
|
|
|
- # Retrieved state from memory.
|
|
|
- return root_state
|
|
|
-
|
|
|
- # Deserialize root state from disk.
|
|
|
- root_state = await self.load_state(_substate_key(client_token, self.state))
|
|
|
- # Create a new root state tree with all substates instantiated.
|
|
|
- fresh_root_state = self.state(_reflex_internal_init=True)
|
|
|
- if root_state is None:
|
|
|
- root_state = fresh_root_state
|
|
|
- else:
|
|
|
- # Ensure all substates exist, even if they were not serialized previously.
|
|
|
- root_state.substates = fresh_root_state.substates
|
|
|
- self.states[client_token] = root_state
|
|
|
- await self.populate_substates(client_token, root_state, root_state)
|
|
|
- return root_state
|
|
|
-
|
|
|
- async def set_state_for_substate(self, client_token: str, substate: BaseState):
|
|
|
- """Set the state for a substate.
|
|
|
-
|
|
|
- Args:
|
|
|
- client_token: The client token.
|
|
|
- substate: The substate to set.
|
|
|
- """
|
|
|
- substate_token = _substate_key(client_token, substate)
|
|
|
-
|
|
|
- if substate._get_was_touched():
|
|
|
- substate._was_touched = False # Reset the touched flag after serializing.
|
|
|
- pickle_state = substate._serialize()
|
|
|
- if pickle_state:
|
|
|
- if not self.states_directory.exists():
|
|
|
- self.states_directory.mkdir(parents=True, exist_ok=True)
|
|
|
- self.token_path(substate_token).write_bytes(pickle_state)
|
|
|
-
|
|
|
- for substate_substate in substate.substates.values():
|
|
|
- await self.set_state_for_substate(client_token, substate_substate)
|
|
|
-
|
|
|
- @override
|
|
|
- async def set_state(self, token: str, state: BaseState):
|
|
|
- """Set the state for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to set the state for.
|
|
|
- state: The state to set.
|
|
|
- """
|
|
|
- client_token, substate = _split_substate_key(token)
|
|
|
- await self.set_state_for_substate(client_token, state)
|
|
|
-
|
|
|
- @override
|
|
|
- @contextlib.asynccontextmanager
|
|
|
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
- """Modify the state for a token while holding exclusive lock.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to modify the state for.
|
|
|
-
|
|
|
- Yields:
|
|
|
- The state for the token.
|
|
|
- """
|
|
|
- # Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
|
- client_token, substate = _split_substate_key(token)
|
|
|
- if client_token not in self._states_locks:
|
|
|
- async with self._state_manager_lock:
|
|
|
- if client_token not in self._states_locks:
|
|
|
- self._states_locks[client_token] = asyncio.Lock()
|
|
|
-
|
|
|
- async with self._states_locks[client_token]:
|
|
|
- state = await self.get_state(token)
|
|
|
- yield state
|
|
|
- await self.set_state(token, state)
|
|
|
-
|
|
|
-
|
|
|
-def _default_lock_expiration() -> int:
|
|
|
- """Get the default lock expiration time.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The default lock expiration time.
|
|
|
- """
|
|
|
- return get_config().redis_lock_expiration
|
|
|
-
|
|
|
-
|
|
|
-def _default_lock_warning_threshold() -> int:
|
|
|
- """Get the default lock warning threshold.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The default lock warning threshold.
|
|
|
- """
|
|
|
- return get_config().redis_lock_warning_threshold
|
|
|
-
|
|
|
-
|
|
|
-class StateManagerRedis(StateManager):
|
|
|
- """A state manager that stores states in redis."""
|
|
|
-
|
|
|
- # The redis client to use.
|
|
|
- redis: Redis
|
|
|
-
|
|
|
- # The token expiration time (s).
|
|
|
- token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
|
|
|
-
|
|
|
- # The maximum time to hold a lock (ms).
|
|
|
- lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
|
|
|
-
|
|
|
- # The maximum time to hold a lock (ms) before warning.
|
|
|
- lock_warning_threshold: int = pydantic.Field(
|
|
|
- default_factory=_default_lock_warning_threshold
|
|
|
- )
|
|
|
-
|
|
|
- # The keyspace subscription string when redis is waiting for lock to be released.
|
|
|
- _redis_notify_keyspace_events: str = (
|
|
|
- "K" # Enable keyspace notifications (target a particular key)
|
|
|
- "g" # For generic commands (DEL, EXPIRE, etc)
|
|
|
- "x" # For expired events
|
|
|
- "e" # For evicted events (i.e. maxmemory exceeded)
|
|
|
- )
|
|
|
-
|
|
|
- # These events indicate that a lock is no longer held.
|
|
|
- _redis_keyspace_lock_release_events: set[bytes] = {
|
|
|
- b"del",
|
|
|
- b"expire",
|
|
|
- b"expired",
|
|
|
- b"evicted",
|
|
|
- }
|
|
|
-
|
|
|
- # Whether keyspace notifications have been enabled.
|
|
|
- _redis_notify_keyspace_events_enabled: bool = False
|
|
|
-
|
|
|
- # The logical database number used by the redis client.
|
|
|
- _redis_db: int = 0
|
|
|
-
|
|
|
- def _get_required_state_classes(
|
|
|
- self,
|
|
|
- target_state_cls: type[BaseState],
|
|
|
- subclasses: bool = False,
|
|
|
- required_state_classes: set[type[BaseState]] | None = None,
|
|
|
- ) -> set[type[BaseState]]:
|
|
|
- """Recursively determine which states are required to fetch the target state.
|
|
|
-
|
|
|
- This will always include potentially dirty substates that depend on vars
|
|
|
- in the target_state_cls.
|
|
|
-
|
|
|
- Args:
|
|
|
- target_state_cls: The target state class being fetched.
|
|
|
- subclasses: Whether to include subclasses of the target state.
|
|
|
- required_state_classes: Recursive argument tracking state classes that have already been seen.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The set of state classes required to fetch the target state.
|
|
|
- """
|
|
|
- if required_state_classes is None:
|
|
|
- required_state_classes = set()
|
|
|
- # Get the substates if requested.
|
|
|
- if subclasses:
|
|
|
- for substate in target_state_cls.get_substates():
|
|
|
- self._get_required_state_classes(
|
|
|
- substate,
|
|
|
- subclasses=True,
|
|
|
- required_state_classes=required_state_classes,
|
|
|
- )
|
|
|
- if target_state_cls in required_state_classes:
|
|
|
- return required_state_classes
|
|
|
- required_state_classes.add(target_state_cls)
|
|
|
-
|
|
|
- # Get dependent substates.
|
|
|
- for pd_substates in target_state_cls._get_potentially_dirty_states():
|
|
|
- self._get_required_state_classes(
|
|
|
- pd_substates,
|
|
|
- subclasses=False,
|
|
|
- required_state_classes=required_state_classes,
|
|
|
- )
|
|
|
-
|
|
|
- # Get the parent state if it exists.
|
|
|
- if parent_state := target_state_cls.get_parent_state():
|
|
|
- self._get_required_state_classes(
|
|
|
- parent_state,
|
|
|
- subclasses=False,
|
|
|
- required_state_classes=required_state_classes,
|
|
|
- )
|
|
|
- return required_state_classes
|
|
|
-
|
|
|
- def _get_populated_states(
|
|
|
- self,
|
|
|
- target_state: BaseState,
|
|
|
- populated_states: dict[str, BaseState] | None = None,
|
|
|
- ) -> dict[str, BaseState]:
|
|
|
- """Recursively determine which states from target_state are already fetched.
|
|
|
-
|
|
|
- Args:
|
|
|
- target_state: The state to check for populated states.
|
|
|
- populated_states: Recursive argument tracking states seen in previous calls.
|
|
|
-
|
|
|
- Returns:
|
|
|
- A dictionary of state full name to state instance.
|
|
|
- """
|
|
|
- if populated_states is None:
|
|
|
- populated_states = {}
|
|
|
- if target_state.get_full_name() in populated_states:
|
|
|
- return populated_states
|
|
|
- populated_states[target_state.get_full_name()] = target_state
|
|
|
- for substate in target_state.substates.values():
|
|
|
- self._get_populated_states(substate, populated_states=populated_states)
|
|
|
- if target_state.parent_state is not None:
|
|
|
- self._get_populated_states(
|
|
|
- target_state.parent_state, populated_states=populated_states
|
|
|
- )
|
|
|
- return populated_states
|
|
|
-
|
|
|
- @override
|
|
|
- async def get_state(
|
|
|
- self,
|
|
|
- token: str,
|
|
|
- top_level: bool = True,
|
|
|
- for_state_instance: BaseState | None = None,
|
|
|
- ) -> BaseState:
|
|
|
- """Get the state for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to get the state for.
|
|
|
- top_level: If true, return an instance of the top-level state (self.state).
|
|
|
- for_state_instance: If provided, attach the requested states to this existing state tree.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The state for the token.
|
|
|
-
|
|
|
- Raises:
|
|
|
- RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
|
|
|
- requested state was not fetched.
|
|
|
- """
|
|
|
- # Split the actual token from the fully qualified substate name.
|
|
|
- token, state_path = _split_substate_key(token)
|
|
|
- if state_path:
|
|
|
- # Get the State class associated with the given path.
|
|
|
- state_cls = self.state.get_class_substate(state_path)
|
|
|
- else:
|
|
|
- raise RuntimeError(
|
|
|
- f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
|
|
- )
|
|
|
-
|
|
|
- # Determine which states we already have.
|
|
|
- flat_state_tree: dict[str, BaseState] = (
|
|
|
- self._get_populated_states(for_state_instance) if for_state_instance else {}
|
|
|
- )
|
|
|
-
|
|
|
- # Determine which states from the tree need to be fetched.
|
|
|
- required_state_classes = sorted(
|
|
|
- self._get_required_state_classes(state_cls, subclasses=True)
|
|
|
- - {type(s) for s in flat_state_tree.values()},
|
|
|
- key=lambda x: x.get_full_name(),
|
|
|
- )
|
|
|
-
|
|
|
- redis_pipeline = self.redis.pipeline()
|
|
|
- for state_cls in required_state_classes:
|
|
|
- redis_pipeline.get(_substate_key(token, state_cls))
|
|
|
-
|
|
|
- for state_cls, redis_state in zip(
|
|
|
- required_state_classes,
|
|
|
- await redis_pipeline.execute(),
|
|
|
- strict=False,
|
|
|
- ):
|
|
|
- state = None
|
|
|
-
|
|
|
- if redis_state is not None:
|
|
|
- # Deserialize the substate.
|
|
|
- with contextlib.suppress(StateSchemaMismatchError):
|
|
|
- state = BaseState._deserialize(data=redis_state)
|
|
|
- if state is None:
|
|
|
- # Key didn't exist or schema mismatch so create a new instance for this token.
|
|
|
- state = state_cls(
|
|
|
- init_substates=False,
|
|
|
- _reflex_internal_init=True,
|
|
|
- )
|
|
|
- flat_state_tree[state.get_full_name()] = state
|
|
|
- if state.get_parent_state() is not None:
|
|
|
- parent_state_name, _dot, state_name = state.get_full_name().rpartition(
|
|
|
- "."
|
|
|
- )
|
|
|
- parent_state = flat_state_tree.get(parent_state_name)
|
|
|
- if parent_state is None:
|
|
|
- raise RuntimeError(
|
|
|
- f"Parent state for {state.get_full_name()} was not found "
|
|
|
- "in the state tree, but should have already been fetched. "
|
|
|
- "This is a bug",
|
|
|
- )
|
|
|
- parent_state.substates[state_name] = state
|
|
|
- state.parent_state = parent_state
|
|
|
-
|
|
|
- # To retain compatibility with previous implementation, by default, we return
|
|
|
- # the top-level state which should always be fetched or already cached.
|
|
|
- if top_level:
|
|
|
- return flat_state_tree[self.state.get_full_name()]
|
|
|
- return flat_state_tree[state_cls.get_full_name()]
|
|
|
-
|
|
|
- @override
|
|
|
- async def set_state(
|
|
|
- self,
|
|
|
- token: str,
|
|
|
- state: BaseState,
|
|
|
- lock_id: bytes | None = None,
|
|
|
- ):
|
|
|
- """Set the state for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to set the state for.
|
|
|
- state: The state to set.
|
|
|
- lock_id: If provided, the lock_key must be set to this value to set the state.
|
|
|
-
|
|
|
- Raises:
|
|
|
- LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
|
|
- RuntimeError: If the state instance doesn't match the state name in the token.
|
|
|
- """
|
|
|
- # Check that we're holding the lock.
|
|
|
- if (
|
|
|
- lock_id is not None
|
|
|
- and await self.redis.get(self._lock_key(token)) != lock_id
|
|
|
- ):
|
|
|
- raise LockExpiredError(
|
|
|
- f"Lock expired for token {token} while processing. Consider increasing "
|
|
|
- f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
|
|
- "or use `@rx.event(background=True)` decorator for long-running tasks."
|
|
|
- )
|
|
|
- elif lock_id is not None:
|
|
|
- time_taken = self.lock_expiration / 1000 - (
|
|
|
- await self.redis.ttl(self._lock_key(token))
|
|
|
- )
|
|
|
- if time_taken > self.lock_warning_threshold / 1000:
|
|
|
- console.warn(
|
|
|
- f"Lock for token {token} was held too long {time_taken=}s, "
|
|
|
- f"use `@rx.event(background=True)` decorator for long-running tasks.",
|
|
|
- dedupe=True,
|
|
|
- )
|
|
|
-
|
|
|
- client_token, substate_name = _split_substate_key(token)
|
|
|
- # If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
|
|
- if state.parent_state is not None and state.get_full_name() != substate_name:
|
|
|
- raise RuntimeError(
|
|
|
- f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
|
|
|
- )
|
|
|
-
|
|
|
- # Recursively set_state on all known substates.
|
|
|
- tasks = [
|
|
|
- asyncio.create_task(
|
|
|
- self.set_state(
|
|
|
- _substate_key(client_token, substate),
|
|
|
- substate,
|
|
|
- lock_id,
|
|
|
- )
|
|
|
- )
|
|
|
- for substate in state.substates.values()
|
|
|
- ]
|
|
|
- # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
|
|
|
- if state._get_was_touched():
|
|
|
- pickle_state = state._serialize()
|
|
|
- if pickle_state:
|
|
|
- await self.redis.set(
|
|
|
- _substate_key(client_token, state),
|
|
|
- pickle_state,
|
|
|
- ex=self.token_expiration,
|
|
|
- )
|
|
|
-
|
|
|
- # Wait for substates to be persisted.
|
|
|
- for t in tasks:
|
|
|
- await t
|
|
|
-
|
|
|
- @override
|
|
|
- @contextlib.asynccontextmanager
|
|
|
- async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
|
|
- """Modify the state for a token while holding exclusive lock.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to modify the state for.
|
|
|
-
|
|
|
- Yields:
|
|
|
- The state for the token.
|
|
|
- """
|
|
|
- async with self._lock(token) as lock_id:
|
|
|
- state = await self.get_state(token)
|
|
|
- yield state
|
|
|
- await self.set_state(token, state, lock_id)
|
|
|
-
|
|
|
- @validator("lock_warning_threshold")
|
|
|
- @classmethod
|
|
|
- def validate_lock_warning_threshold(
|
|
|
- cls, lock_warning_threshold: int, values: dict[str, int]
|
|
|
- ):
|
|
|
- """Validate the lock warning threshold.
|
|
|
-
|
|
|
- Args:
|
|
|
- lock_warning_threshold: The lock warning threshold.
|
|
|
- values: The validated attributes.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The lock warning threshold.
|
|
|
-
|
|
|
- Raises:
|
|
|
- InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
|
|
|
- """
|
|
|
- if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
|
|
|
- raise InvalidLockWarningThresholdError(
|
|
|
- f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
|
|
|
- )
|
|
|
- return lock_warning_threshold
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def _lock_key(token: str) -> bytes:
|
|
|
- """Get the redis key for a token's lock.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to get the lock key for.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The redis lock key for the token.
|
|
|
- """
|
|
|
- # All substates share the same lock domain, so ignore any substate path suffix.
|
|
|
- client_token = _split_substate_key(token)[0]
|
|
|
- return f"{client_token}_lock".encode()
|
|
|
-
|
|
|
- async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
|
|
- """Try to get a redis lock for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- lock_key: The redis key for the lock.
|
|
|
- lock_id: The ID of the lock.
|
|
|
-
|
|
|
- Returns:
|
|
|
- True if the lock was obtained.
|
|
|
- """
|
|
|
- return await self.redis.set(
|
|
|
- lock_key,
|
|
|
- lock_id,
|
|
|
- px=self.lock_expiration,
|
|
|
- nx=True, # only set if it doesn't exist
|
|
|
- )
|
|
|
-
|
|
|
- async def _get_pubsub_message(
|
|
|
- self, pubsub: PubSub, timeout: float | None = None
|
|
|
- ) -> None:
|
|
|
- """Get lock release events from the pubsub.
|
|
|
-
|
|
|
- Args:
|
|
|
- pubsub: The pubsub to get a message from.
|
|
|
- timeout: Remaining time to wait for a message.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The message.
|
|
|
- """
|
|
|
- if timeout is None:
|
|
|
- timeout = self.lock_expiration / 1000.0
|
|
|
-
|
|
|
- started = time.time()
|
|
|
- message = await pubsub.get_message(
|
|
|
- ignore_subscribe_messages=True,
|
|
|
- timeout=timeout,
|
|
|
- )
|
|
|
- if (
|
|
|
- message is None
|
|
|
- or message["data"] not in self._redis_keyspace_lock_release_events
|
|
|
- ):
|
|
|
- remaining = timeout - (time.time() - started)
|
|
|
- if remaining <= 0:
|
|
|
- return
|
|
|
- await self._get_pubsub_message(pubsub, timeout=remaining)
|
|
|
-
|
|
|
- async def _enable_keyspace_notifications(self):
|
|
|
- """Enable keyspace notifications for the redis server.
|
|
|
-
|
|
|
- Raises:
|
|
|
- ResponseError: when the keyspace config cannot be set.
|
|
|
- """
|
|
|
- if self._redis_notify_keyspace_events_enabled:
|
|
|
- return
|
|
|
- # Find out which logical database index is being used.
|
|
|
- self._redis_db = self.redis.get_connection_kwargs().get("db", self._redis_db)
|
|
|
-
|
|
|
- try:
|
|
|
- await self.redis.config_set(
|
|
|
- "notify-keyspace-events",
|
|
|
- self._redis_notify_keyspace_events,
|
|
|
- )
|
|
|
- except ResponseError:
|
|
|
- # Some redis servers only allow out-of-band configuration, so ignore errors here.
|
|
|
- if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
|
|
|
- raise
|
|
|
- self._redis_notify_keyspace_events_enabled = True
|
|
|
-
|
|
|
- async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
|
|
|
- """Wait for a redis lock to be released via pubsub.
|
|
|
-
|
|
|
- Coroutine will not return until the lock is obtained.
|
|
|
-
|
|
|
- Args:
|
|
|
- lock_key: The redis key for the lock.
|
|
|
- lock_id: The ID of the lock.
|
|
|
- """
|
|
|
- # Enable keyspace notifications for the lock key, so we know when it is available.
|
|
|
- await self._enable_keyspace_notifications()
|
|
|
- lock_key_channel = f"__keyspace@{self._redis_db}__:{lock_key.decode()}"
|
|
|
- async with self.redis.pubsub() as pubsub:
|
|
|
- await pubsub.psubscribe(lock_key_channel)
|
|
|
- # wait for the lock to be released
|
|
|
- while True:
|
|
|
- # fast path
|
|
|
- if await self._try_get_lock(lock_key, lock_id):
|
|
|
- return
|
|
|
- # wait for lock events
|
|
|
- await self._get_pubsub_message(pubsub)
|
|
|
-
|
|
|
- @contextlib.asynccontextmanager
|
|
|
- async def _lock(self, token: str):
|
|
|
- """Obtain a redis lock for a token.
|
|
|
-
|
|
|
- Args:
|
|
|
- token: The token to obtain a lock for.
|
|
|
-
|
|
|
- Yields:
|
|
|
- The ID of the lock (to be passed to set_state).
|
|
|
-
|
|
|
- Raises:
|
|
|
- LockExpiredError: If the lock has expired while processing the event.
|
|
|
- """
|
|
|
- lock_key = self._lock_key(token)
|
|
|
- lock_id = uuid.uuid4().hex.encode()
|
|
|
-
|
|
|
- if not await self._try_get_lock(lock_key, lock_id):
|
|
|
- # Missed the fast-path to get lock, subscribe for lock delete/expire events
|
|
|
- await self._wait_lock(lock_key, lock_id)
|
|
|
- state_is_locked = True
|
|
|
-
|
|
|
- try:
|
|
|
- yield lock_id
|
|
|
- except LockExpiredError:
|
|
|
- state_is_locked = False
|
|
|
- raise
|
|
|
- finally:
|
|
|
- if state_is_locked:
|
|
|
- # only delete our lock
|
|
|
- await self.redis.delete(lock_key)
|
|
|
-
|
|
|
- async def close(self):
|
|
|
- """Explicitly close the redis connection and connection_pool.
|
|
|
-
|
|
|
- It is necessary in testing scenarios to close between asyncio test cases
|
|
|
- to avoid having lingering redis connections associated with event loops
|
|
|
- that will be closed (each test case uses its own event loop).
|
|
|
-
|
|
|
- Note: Connections will be automatically reopened when needed.
|
|
|
- """
|
|
|
- await self.redis.aclose(close_connection_pool=True)
|
|
|
-
|
|
|
-
|
|
|
-def get_state_manager() -> StateManager:
|
|
|
- """Get the state manager for the app that is currently running.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The state manager.
|
|
|
- """
|
|
|
- return prerequisites.get_and_validate_app().app.state_manager
|
|
|
-
|
|
|
-
|
|
|
-class MutableProxy(wrapt.ObjectProxy):
|
|
|
- """A proxy for a mutable object that tracks changes."""
|
|
|
-
|
|
|
- # Hint for finding the base class of the proxy.
|
|
|
- __base_proxy__ = "MutableProxy"
|
|
|
-
|
|
|
- # Methods on wrapped objects which should mark the state as dirty.
|
|
|
- __mark_dirty_attrs__ = {
|
|
|
- "add",
|
|
|
- "append",
|
|
|
- "clear",
|
|
|
- "difference_update",
|
|
|
- "discard",
|
|
|
- "extend",
|
|
|
- "insert",
|
|
|
- "intersection_update",
|
|
|
- "pop",
|
|
|
- "popitem",
|
|
|
- "remove",
|
|
|
- "reverse",
|
|
|
- "setdefault",
|
|
|
- "sort",
|
|
|
- "symmetric_difference_update",
|
|
|
- "update",
|
|
|
- }
|
|
|
-
|
|
|
- # Methods on wrapped objects might return mutable objects that should be tracked.
|
|
|
- __wrap_mutable_attrs__ = {
|
|
|
- "get",
|
|
|
- "setdefault",
|
|
|
- }
|
|
|
-
|
|
|
- # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
|
|
|
- __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
|
|
|
- pydantic.BaseModel.__dict__
|
|
|
- )
|
|
|
-
|
|
|
- # These types will be wrapped in MutableProxy
|
|
|
- __mutable_types__ = (
|
|
|
- list,
|
|
|
- dict,
|
|
|
- set,
|
|
|
- Base,
|
|
|
- DeclarativeBase,
|
|
|
- BaseModelV2,
|
|
|
- BaseModelV1,
|
|
|
- )
|
|
|
-
|
|
|
- # Dynamically generated classes for tracking dataclass mutations.
|
|
|
- __dataclass_proxies__: dict[type, type] = {}
|
|
|
-
|
|
|
- def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
|
|
|
- """Create a proxy instance for a mutable object that tracks changes.
|
|
|
-
|
|
|
- Args:
|
|
|
- wrapped: The object to proxy.
|
|
|
- *args: Other args passed to MutableProxy (ignored).
|
|
|
- **kwargs: Other kwargs passed to MutableProxy (ignored).
|
|
|
-
|
|
|
- Returns:
|
|
|
- The proxy instance.
|
|
|
- """
|
|
|
- if dataclasses.is_dataclass(wrapped):
|
|
|
- wrapped_cls = type(wrapped)
|
|
|
- wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
|
|
|
- # Find the associated class
|
|
|
- if wrapper_cls_name not in cls.__dataclass_proxies__:
|
|
|
- # Create a new class that has the __dataclass_fields__ defined
|
|
|
- cls.__dataclass_proxies__[wrapper_cls_name] = type(
|
|
|
- wrapper_cls_name,
|
|
|
- (cls,),
|
|
|
- {
|
|
|
- dataclasses._FIELDS: getattr( # pyright: ignore [reportAttributeAccessIssue]
|
|
|
- wrapped_cls,
|
|
|
- dataclasses._FIELDS, # pyright: ignore [reportAttributeAccessIssue]
|
|
|
- ),
|
|
|
- },
|
|
|
- )
|
|
|
- cls = cls.__dataclass_proxies__[wrapper_cls_name]
|
|
|
- return super().__new__(cls)
|
|
|
-
|
|
|
- def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
|
|
- """Create a proxy for a mutable object that tracks changes.
|
|
|
-
|
|
|
- Args:
|
|
|
- wrapped: The object to proxy.
|
|
|
- state: The state to mark dirty when the object is changed.
|
|
|
- field_name: The name of the field on the state associated with the
|
|
|
- wrapped object.
|
|
|
- """
|
|
|
- super().__init__(wrapped)
|
|
|
- self._self_state = state
|
|
|
- self._self_field_name = field_name
|
|
|
-
|
|
|
- def __repr__(self) -> str:
|
|
|
- """Get the representation of the wrapped object.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The representation of the wrapped object.
|
|
|
- """
|
|
|
- return f"{type(self).__name__}({self.__wrapped__})"
|
|
|
-
|
|
|
- def _mark_dirty(
|
|
|
- self,
|
|
|
- wrapped: Callable | None = None,
|
|
|
- instance: BaseState | None = None,
|
|
|
- args: tuple = (),
|
|
|
- kwargs: dict | None = None,
|
|
|
- ) -> Any:
|
|
|
- """Mark the state as dirty, then call a wrapped function.
|
|
|
-
|
|
|
- Intended for use with `FunctionWrapper` from the `wrapt` library.
|
|
|
-
|
|
|
- Args:
|
|
|
- wrapped: The wrapped function.
|
|
|
- instance: The instance of the wrapped function.
|
|
|
- args: The args for the wrapped function.
|
|
|
- kwargs: The kwargs for the wrapped function.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The result of the wrapped function.
|
|
|
- """
|
|
|
- self._self_state.dirty_vars.add(self._self_field_name)
|
|
|
- self._self_state._mark_dirty()
|
|
|
- if wrapped is not None:
|
|
|
- return wrapped(*args, **(kwargs or {}))
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def _is_mutable_type(cls, value: Any) -> bool:
|
|
|
- """Check if a value is of a mutable type and should be wrapped.
|
|
|
-
|
|
|
- Args:
|
|
|
- value: The value to check.
|
|
|
-
|
|
|
- Returns:
|
|
|
- Whether the value is of a mutable type.
|
|
|
- """
|
|
|
- return isinstance(value, cls.__mutable_types__) or (
|
|
|
- dataclasses.is_dataclass(value) and not isinstance(value, Var)
|
|
|
- )
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def _is_called_from_dataclasses_internal() -> bool:
|
|
|
- """Check if the current function is called from dataclasses helper.
|
|
|
-
|
|
|
- Returns:
|
|
|
- Whether the current function is called from dataclasses internal code.
|
|
|
- """
|
|
|
- # Walk up the stack a bit to see if we are called from dataclasses
|
|
|
- # internal code, for example `asdict` or `astuple`.
|
|
|
- frame = inspect.currentframe()
|
|
|
- for _ in range(5):
|
|
|
- # Why not `inspect.stack()` -- this is much faster!
|
|
|
- if not (frame := frame and frame.f_back):
|
|
|
- break
|
|
|
- if inspect.getfile(frame) == dataclasses.__file__:
|
|
|
- return True
|
|
|
- return False
|
|
|
-
|
|
|
- def _wrap_recursive(self, value: Any) -> Any:
|
|
|
- """Wrap a value recursively if it is mutable.
|
|
|
-
|
|
|
- Args:
|
|
|
- value: The value to wrap.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The wrapped value.
|
|
|
- """
|
|
|
- # When called from dataclasses internal code, return the unwrapped value
|
|
|
- if self._is_called_from_dataclasses_internal():
|
|
|
- return value
|
|
|
- # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
|
|
- if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
|
|
|
- base_cls = globals()[self.__base_proxy__]
|
|
|
- return base_cls(
|
|
|
- wrapped=value,
|
|
|
- state=self._self_state,
|
|
|
- field_name=self._self_field_name,
|
|
|
- )
|
|
|
- return value
|
|
|
-
|
|
|
- def _wrap_recursive_decorator(
|
|
|
- self, wrapped: Callable, instance: BaseState, args: list, kwargs: dict
|
|
|
- ) -> Any:
|
|
|
- """Wrap a function that returns a possibly mutable value.
|
|
|
-
|
|
|
- Intended for use with `FunctionWrapper` from the `wrapt` library.
|
|
|
-
|
|
|
- Args:
|
|
|
- wrapped: The wrapped function.
|
|
|
- instance: The instance of the wrapped function.
|
|
|
- args: The args for the wrapped function.
|
|
|
- kwargs: The kwargs for the wrapped function.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The result of the wrapped function (possibly wrapped in a MutableProxy).
|
|
|
- """
|
|
|
- return self._wrap_recursive(wrapped(*args, **kwargs))
|
|
|
-
|
|
|
- def __getattr__(self, __name: str) -> Any:
|
|
|
- """Get the attribute on the proxied object and return a proxy if mutable.
|
|
|
-
|
|
|
- Args:
|
|
|
- __name: The name of the attribute.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The attribute value.
|
|
|
- """
|
|
|
- value = super().__getattr__(__name)
|
|
|
-
|
|
|
- if callable(value):
|
|
|
- if __name in self.__mark_dirty_attrs__:
|
|
|
- # Wrap special callables, like "append", which should mark state dirty.
|
|
|
- value = wrapt.FunctionWrapper(value, self._mark_dirty)
|
|
|
-
|
|
|
- if __name in self.__wrap_mutable_attrs__:
|
|
|
- # Wrap methods that may return mutable objects tied to the state.
|
|
|
- value = wrapt.FunctionWrapper(
|
|
|
- value,
|
|
|
- self._wrap_recursive_decorator,
|
|
|
- )
|
|
|
-
|
|
|
- if (
|
|
|
- isinstance(self.__wrapped__, Base)
|
|
|
- and __name not in self.__never_wrap_base_attrs__
|
|
|
- and hasattr(value, "__func__")
|
|
|
- ):
|
|
|
- # Wrap methods called on Base subclasses, which might do _anything_
|
|
|
- return wrapt.FunctionWrapper(
|
|
|
- functools.partial(value.__func__, self), # pyright: ignore [reportFunctionMemberAccess]
|
|
|
- self._wrap_recursive_decorator,
|
|
|
- )
|
|
|
-
|
|
|
- if self._is_mutable_type(value) and __name not in (
|
|
|
- "__wrapped__",
|
|
|
- "_self_state",
|
|
|
- "__dict__",
|
|
|
- ):
|
|
|
- # Recursively wrap mutable attribute values retrieved through this proxy.
|
|
|
- return self._wrap_recursive(value)
|
|
|
-
|
|
|
- return value
|
|
|
-
|
|
|
- def __getitem__(self, key: Any) -> Any:
|
|
|
- """Get the item on the proxied object and return a proxy if mutable.
|
|
|
-
|
|
|
- Args:
|
|
|
- key: The key of the item.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The item value.
|
|
|
- """
|
|
|
- value = super().__getitem__(key)
|
|
|
- if isinstance(key, slice) and isinstance(value, list):
|
|
|
- return [self._wrap_recursive(item) for item in value]
|
|
|
- # Recursively wrap mutable items retrieved through this proxy.
|
|
|
- return self._wrap_recursive(value)
|
|
|
-
|
|
|
- def __iter__(self) -> Any:
|
|
|
- """Iterate over the proxied object and return a proxy if mutable.
|
|
|
-
|
|
|
- Yields:
|
|
|
- Each item value (possibly wrapped in MutableProxy).
|
|
|
- """
|
|
|
- for value in super().__iter__():
|
|
|
- # Recursively wrap mutable items retrieved through this proxy.
|
|
|
- yield self._wrap_recursive(value)
|
|
|
-
|
|
|
- def __delattr__(self, name: str):
|
|
|
- """Delete the attribute on the proxied object and mark state dirty.
|
|
|
-
|
|
|
- Args:
|
|
|
- name: The name of the attribute.
|
|
|
- """
|
|
|
- self._mark_dirty(super().__delattr__, args=(name,))
|
|
|
-
|
|
|
- def __delitem__(self, key: str):
|
|
|
- """Delete the item on the proxied object and mark state dirty.
|
|
|
-
|
|
|
- Args:
|
|
|
- key: The key of the item.
|
|
|
- """
|
|
|
- self._mark_dirty(super().__delitem__, args=(key,))
|
|
|
-
|
|
|
- def __setitem__(self, key: str, value: Any):
|
|
|
- """Set the item on the proxied object and mark state dirty.
|
|
|
-
|
|
|
- Args:
|
|
|
- key: The key of the item.
|
|
|
- value: The value of the item.
|
|
|
- """
|
|
|
- self._mark_dirty(super().__setitem__, args=(key, value))
|
|
|
-
|
|
|
- def __setattr__(self, name: str, value: Any):
|
|
|
- """Set the attribute on the proxied object and mark state dirty.
|
|
|
-
|
|
|
- If the attribute starts with "_self_", then the state is NOT marked
|
|
|
- dirty as these are internal proxy attributes.
|
|
|
-
|
|
|
- Args:
|
|
|
- name: The name of the attribute.
|
|
|
- value: The value of the attribute.
|
|
|
- """
|
|
|
- if name.startswith("_self_"):
|
|
|
- # Special case attributes of the proxy itself, not applied to the wrapped object.
|
|
|
- super().__setattr__(name, value)
|
|
|
- return
|
|
|
- self._mark_dirty(super().__setattr__, args=(name, value))
|
|
|
-
|
|
|
- def __copy__(self) -> Any:
|
|
|
- """Return a copy of the proxy.
|
|
|
-
|
|
|
- Returns:
|
|
|
- A copy of the wrapped object, unconnected to the proxy.
|
|
|
- """
|
|
|
- return copy.copy(self.__wrapped__)
|
|
|
-
|
|
|
- def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Any:
|
|
|
- """Return a deepcopy of the proxy.
|
|
|
-
|
|
|
- Args:
|
|
|
- memo: The memo dict to use for the deepcopy.
|
|
|
-
|
|
|
- Returns:
|
|
|
- A deepcopy of the wrapped object, unconnected to the proxy.
|
|
|
- """
|
|
|
- return copy.deepcopy(self.__wrapped__, memo=memo)
|
|
|
-
|
|
|
- def __reduce_ex__(self, protocol_version: SupportsIndex):
|
|
|
- """Get the state for redis serialization.
|
|
|
-
|
|
|
- This method is called by cloudpickle to serialize the object.
|
|
|
-
|
|
|
- It explicitly serializes the wrapped object, stripping off the mutable proxy.
|
|
|
-
|
|
|
- Args:
|
|
|
- protocol_version: The protocol version.
|
|
|
-
|
|
|
- Returns:
|
|
|
- Tuple of (wrapped class, empty args, class __getstate__)
|
|
|
- """
|
|
|
- return self.__wrapped__.__reduce_ex__(protocol_version)
|
|
|
-
|
|
|
-
|
|
|
-@serializer
|
|
|
-def serialize_mutable_proxy(mp: MutableProxy):
|
|
|
- """Return the wrapped value of a MutableProxy.
|
|
|
-
|
|
|
- Args:
|
|
|
- mp: The MutableProxy to serialize.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The wrapped object.
|
|
|
- """
|
|
|
- return mp.__wrapped__
|
|
|
-
|
|
|
-
|
|
|
-_orig_json_encoder_default = json.JSONEncoder.default
|
|
|
-
|
|
|
-
|
|
|
-def _json_encoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
|
|
|
- """Wrap JSONEncoder.default to handle MutableProxy objects.
|
|
|
-
|
|
|
- Args:
|
|
|
- self: the JSONEncoder instance.
|
|
|
- o: the object to serialize.
|
|
|
-
|
|
|
- Returns:
|
|
|
- A JSON-able object.
|
|
|
- """
|
|
|
- try:
|
|
|
- return o.__wrapped__
|
|
|
- except AttributeError:
|
|
|
- pass
|
|
|
- return _orig_json_encoder_default(self, o)
|
|
|
-
|
|
|
-
|
|
|
-json.JSONEncoder.default = _json_encoder_default_wrapper
|
|
|
-
|
|
|
-
|
|
|
-class ImmutableMutableProxy(MutableProxy):
|
|
|
- """A proxy for a mutable object that tracks changes.
|
|
|
-
|
|
|
- This wrapper comes from StateProxy, and will raise an exception if an attempt is made
|
|
|
- to modify the wrapped object when the StateProxy is immutable.
|
|
|
- """
|
|
|
-
|
|
|
- # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
|
|
|
- __base_proxy__ = "ImmutableMutableProxy"
|
|
|
-
|
|
|
- def _mark_dirty(
|
|
|
- self,
|
|
|
- wrapped: Callable | None = None,
|
|
|
- instance: BaseState | None = None,
|
|
|
- args: tuple = (),
|
|
|
- kwargs: dict | None = None,
|
|
|
- ) -> Any:
|
|
|
- """Raise an exception when an attempt is made to modify the object.
|
|
|
-
|
|
|
- Intended for use with `FunctionWrapper` from the `wrapt` library.
|
|
|
-
|
|
|
- Args:
|
|
|
- wrapped: The wrapped function.
|
|
|
- instance: The instance of the wrapped function.
|
|
|
- args: The args for the wrapped function.
|
|
|
- kwargs: The kwargs for the wrapped function.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The result of the wrapped function.
|
|
|
-
|
|
|
- Raises:
|
|
|
- ImmutableStateError: if the StateProxy is not mutable.
|
|
|
- """
|
|
|
- if not self._self_state._is_mutable():
|
|
|
- raise ImmutableStateError(
|
|
|
- "Background task StateProxy is immutable outside of a context "
|
|
|
- "manager. Use `async with self` to modify state."
|
|
|
- )
|
|
|
- return super()._mark_dirty(
|
|
|
- wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-def code_uses_state_contexts(javascript_code: str) -> bool:
|
|
|
- """Check if the rendered Javascript uses state contexts.
|
|
|
-
|
|
|
- Args:
|
|
|
- javascript_code: The Javascript code to check.
|
|
|
-
|
|
|
- Returns:
|
|
|
- True if the code attempts to access a member of StateContexts.
|
|
|
- """
|
|
|
- return bool("useContext(StateContexts" in javascript_code)
|
|
|
-
|
|
|
-
|
|
|
-def reload_state_module(
|
|
|
- module: str,
|
|
|
- state: type[BaseState] = State,
|
|
|
-) -> None:
|
|
|
- """Reset rx.State subclasses to avoid conflict when reloading.
|
|
|
-
|
|
|
- Args:
|
|
|
- module: The module to reload.
|
|
|
- state: Recursive argument for the state class to reload.
|
|
|
-
|
|
|
- """
|
|
|
- # Clean out all potentially dirty states of reloaded modules.
|
|
|
- for pd_state in tuple(state._potentially_dirty_states):
|
|
|
- with contextlib.suppress(ValueError):
|
|
|
- if (
|
|
|
- state.get_root_state().get_class_substate(pd_state).__module__ == module
|
|
|
- and module is not None
|
|
|
- ):
|
|
|
- state._potentially_dirty_states.remove(pd_state)
|
|
|
- for subclass in tuple(state.class_subclasses):
|
|
|
- reload_state_module(module=module, state=subclass)
|
|
|
- if subclass.__module__ == module and module is not None:
|
|
|
- all_base_state_classes.pop(subclass.get_full_name(), None)
|
|
|
- state.class_subclasses.remove(subclass)
|
|
|
- state._always_dirty_substates.discard(subclass.get_name())
|
|
|
- state._var_dependencies = {}
|
|
|
- state._init_var_dependency_dicts()
|
|
|
- state.get_class_substate.cache_clear()
|
|
|
+from reflex.istate.manager import LockExpiredError as LockExpiredError # noqa: E402
|
|
|
+from reflex.istate.manager import StateManager as StateManager # noqa: E402
|
|
|
+from reflex.istate.manager import StateManagerDisk as StateManagerDisk # noqa: E402
|
|
|
+from reflex.istate.manager import StateManagerMemory as StateManagerMemory # noqa: E402
|
|
|
+from reflex.istate.manager import StateManagerRedis as StateManagerRedis # noqa: E402
|
|
|
+from reflex.istate.manager import get_state_manager as get_state_manager # noqa: E402
|
|
|
+from reflex.istate.manager import ( # noqa: E402
|
|
|
+ reset_disk_state_manager as reset_disk_state_manager,
|
|
|
+)
|