1
0
Эх сурвалжийг харах

move proxy classes into proxy.py (#5224)

* move proxy classes into proxy.py

* delete mutable proxy from state.py

* merge slicing a list back into new code

* move managers to a file

* remove pydantic from state managers

* import override from typing extension

* simplify some imports

* simplify even more

* fix the validator
Khaleel Al-Adhami 2 долоо хоног өмнө
parent
commit
d18989e622

+ 858 - 0
reflex/istate/manager.py

@@ -0,0 +1,858 @@
+"""State manager for managing client states."""
+
+import asyncio
+import contextlib
+import dataclasses
+import functools
+import time
+import uuid
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator
+from hashlib import md5
+from pathlib import Path
+
+from redis import ResponseError
+from redis.asyncio import Redis
+from redis.asyncio.client import PubSub
+from typing_extensions import override
+
+from reflex import constants
+from reflex.config import environment, get_config
+from reflex.state import BaseState, _split_substate_key, _substate_key
+from reflex.utils import console, path_ops, prerequisites
+from reflex.utils.exceptions import (
+    InvalidLockWarningThresholdError,
+    InvalidStateManagerModeError,
+    LockExpiredError,
+    StateSchemaMismatchError,
+)
+
+
+@dataclasses.dataclass
+class StateManager(ABC):
+    """A class to manage many client states."""
+
+    # The state class to use.
+    state: type[BaseState]
+
+    @classmethod
+    def create(cls, state: type[BaseState]):
+        """Create a new state manager.
+
+        Args:
+            state: The state class to use.
+
+        Raises:
+            InvalidStateManagerModeError: If the state manager mode is invalid.
+
+        Returns:
+            The state manager (either disk, memory or redis).
+        """
+        config = get_config()
+        if prerequisites.parse_redis_url() is not None:
+            config.state_manager_mode = constants.StateManagerMode.REDIS
+        if config.state_manager_mode == constants.StateManagerMode.MEMORY:
+            return StateManagerMemory(state=state)
+        if config.state_manager_mode == constants.StateManagerMode.DISK:
+            return StateManagerDisk(state=state)
+        if config.state_manager_mode == constants.StateManagerMode.REDIS:
+            redis = prerequisites.get_redis()
+            if redis is not None:
+                # make sure expiration values are obtained only from the config object on creation
+                return StateManagerRedis(
+                    state=state,
+                    redis=redis,
+                    token_expiration=config.redis_token_expiration,
+                    lock_expiration=config.redis_lock_expiration,
+                    lock_warning_threshold=config.redis_lock_warning_threshold,
+                )
+        raise InvalidStateManagerModeError(
+            f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
+        )
+
+    @abstractmethod
+    async def get_state(self, token: str) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        pass
+
+    @abstractmethod
+    async def set_state(self, token: str, state: BaseState):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        pass
+
+    @abstractmethod
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        yield self.state()
+
+
+@dataclasses.dataclass
+class StateManagerMemory(StateManager):
+    """A state manager that stores states in memory."""
+
+    # The mapping of client ids to states.
+    states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
+
+    # The mutex ensures the dict of mutexes is updated exclusively
+    _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
+
+    # The dict of mutexes for each client
+    _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
+        default_factory=dict, init=False
+    )
+
+    @override
+    async def get_state(self, token: str) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        token = _split_substate_key(token)[0]
+        if token not in self.states:
+            self.states[token] = self.state(_reflex_internal_init=True)
+        return self.states[token]
+
+    @override
+    async def set_state(self, token: str, state: BaseState):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        pass
+
+    @override
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        token = _split_substate_key(token)[0]
+        if token not in self._states_locks:
+            async with self._state_manager_lock:
+                if token not in self._states_locks:
+                    self._states_locks[token] = asyncio.Lock()
+
+        async with self._states_locks[token]:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state)
+
+
+def _default_token_expiration() -> int:
+    """Get the default token expiration time.
+
+    Returns:
+        The default token expiration time.
+    """
+    return get_config().redis_token_expiration
+
+
+def reset_disk_state_manager():
+    """Reset the disk state manager."""
+    states_directory = prerequisites.get_states_dir()
+    if states_directory.exists():
+        for path in states_directory.iterdir():
+            path.unlink()
+
+
+@dataclasses.dataclass
+class StateManagerDisk(StateManager):
+    """A state manager that stores states in memory."""
+
+    # The mapping of client ids to states.
+    states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
+
+    # The mutex ensures the dict of mutexes is updated exclusively
+    _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
+
+    # The dict of mutexes for each client
+    _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
+        default_factory=dict,
+        init=False,
+    )
+
+    # The token expiration time (s).
+    token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
+
+    def __post_init_(self):
+        """Create a new state manager."""
+        path_ops.mkdir(self.states_directory)
+
+        self._purge_expired_states()
+
+    @functools.cached_property
+    def states_directory(self) -> Path:
+        """Get the states directory.
+
+        Returns:
+            The states directory.
+        """
+        return prerequisites.get_states_dir()
+
+    def _purge_expired_states(self):
+        """Purge expired states from the disk."""
+        import time
+
+        for path in path_ops.ls(self.states_directory):
+            # check path is a pickle file
+            if path.suffix != ".pkl":
+                continue
+
+            # load last edited field from file
+            last_edited = path.stat().st_mtime
+
+            # check if the file is older than the token expiration time
+            if time.time() - last_edited > self.token_expiration:
+                # remove the file
+                path.unlink()
+
+    def token_path(self, token: str) -> Path:
+        """Get the path for a token.
+
+        Args:
+            token: The token to get the path for.
+
+        Returns:
+            The path for the token.
+        """
+        return (
+            self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
+        ).absolute()
+
+    async def load_state(self, token: str) -> BaseState | None:
+        """Load a state object based on the provided token.
+
+        Args:
+            token: The token used to identify the state object.
+
+        Returns:
+            The loaded state object or None.
+        """
+        token_path = self.token_path(token)
+
+        if token_path.exists():
+            try:
+                with token_path.open(mode="rb") as file:
+                    return BaseState._deserialize(fp=file)
+            except Exception:
+                pass
+
+    async def populate_substates(
+        self, client_token: str, state: BaseState, root_state: BaseState
+    ):
+        """Populate the substates of a state object.
+
+        Args:
+            client_token: The client token.
+            state: The state object to populate.
+            root_state: The root state object.
+        """
+        for substate in state.get_substates():
+            substate_token = _substate_key(client_token, substate)
+
+            fresh_instance = await root_state.get_state(substate)
+            instance = await self.load_state(substate_token)
+            if instance is not None:
+                # Ensure all substates exist, even if they weren't serialized previously.
+                instance.substates = fresh_instance.substates
+            else:
+                instance = fresh_instance
+            state.substates[substate.get_name()] = instance
+            instance.parent_state = state
+
+            await self.populate_substates(client_token, instance, root_state)
+
+    @override
+    async def get_state(
+        self,
+        token: str,
+    ) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        client_token = _split_substate_key(token)[0]
+        root_state = self.states.get(client_token)
+        if root_state is not None:
+            # Retrieved state from memory.
+            return root_state
+
+        # Deserialize root state from disk.
+        root_state = await self.load_state(_substate_key(client_token, self.state))
+        # Create a new root state tree with all substates instantiated.
+        fresh_root_state = self.state(_reflex_internal_init=True)
+        if root_state is None:
+            root_state = fresh_root_state
+        else:
+            # Ensure all substates exist, even if they were not serialized previously.
+            root_state.substates = fresh_root_state.substates
+        self.states[client_token] = root_state
+        await self.populate_substates(client_token, root_state, root_state)
+        return root_state
+
+    async def set_state_for_substate(self, client_token: str, substate: BaseState):
+        """Set the state for a substate.
+
+        Args:
+            client_token: The client token.
+            substate: The substate to set.
+        """
+        substate_token = _substate_key(client_token, substate)
+
+        if substate._get_was_touched():
+            substate._was_touched = False  # Reset the touched flag after serializing.
+            pickle_state = substate._serialize()
+            if pickle_state:
+                if not self.states_directory.exists():
+                    self.states_directory.mkdir(parents=True, exist_ok=True)
+                self.token_path(substate_token).write_bytes(pickle_state)
+
+        for substate_substate in substate.substates.values():
+            await self.set_state_for_substate(client_token, substate_substate)
+
+    @override
+    async def set_state(self, token: str, state: BaseState):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        client_token, substate = _split_substate_key(token)
+        await self.set_state_for_substate(client_token, state)
+
+    @override
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        client_token, substate = _split_substate_key(token)
+        if client_token not in self._states_locks:
+            async with self._state_manager_lock:
+                if client_token not in self._states_locks:
+                    self._states_locks[client_token] = asyncio.Lock()
+
+        async with self._states_locks[client_token]:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state)
+
+
+def _default_lock_expiration() -> int:
+    """Get the default lock expiration time.
+
+    Returns:
+        The default lock expiration time.
+    """
+    return get_config().redis_lock_expiration
+
+
+def _default_lock_warning_threshold() -> int:
+    """Get the default lock warning threshold.
+
+    Returns:
+        The default lock warning threshold.
+    """
+    return get_config().redis_lock_warning_threshold
+
+
+@dataclasses.dataclass
+class StateManagerRedis(StateManager):
+    """A state manager that stores states in redis."""
+
+    # The redis client to use.
+    redis: Redis
+
+    # The token expiration time (s).
+    token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
+
+    # The maximum time to hold a lock (ms).
+    lock_expiration: int = dataclasses.field(default_factory=_default_lock_expiration)
+
+    # The maximum time to hold a lock (ms) before warning.
+    lock_warning_threshold: int = dataclasses.field(
+        default_factory=_default_lock_warning_threshold
+    )
+
+    # The keyspace subscription string when redis is waiting for lock to be released.
+    _redis_notify_keyspace_events: str = dataclasses.field(
+        default="K"  # Enable keyspace notifications (target a particular key)
+        "g"  # For generic commands (DEL, EXPIRE, etc)
+        "x"  # For expired events
+        "e"  # For evicted events (i.e. maxmemory exceeded)
+    )
+
+    # These events indicate that a lock is no longer held.
+    _redis_keyspace_lock_release_events: set[bytes] = dataclasses.field(
+        default_factory=lambda: {
+            b"del",
+            b"expire",
+            b"expired",
+            b"evicted",
+        }
+    )
+
+    # Whether keyspace notifications have been enabled.
+    _redis_notify_keyspace_events_enabled: bool = dataclasses.field(default=False)
+
+    # The logical database number used by the redis client.
+    _redis_db: int = dataclasses.field(default=0)
+
+    def __post_init__(self):
+        """Validate the lock warning threshold.
+
+        Raises:
+            InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
+        """
+        if self.lock_warning_threshold >= (lock_expiration := self.lock_expiration):
+            raise InvalidLockWarningThresholdError(
+                f"The lock warning threshold({self.lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
+            )
+
+    def _get_required_state_classes(
+        self,
+        target_state_cls: type[BaseState],
+        subclasses: bool = False,
+        required_state_classes: set[type[BaseState]] | None = None,
+    ) -> set[type[BaseState]]:
+        """Recursively determine which states are required to fetch the target state.
+
+        This will always include potentially dirty substates that depend on vars
+        in the target_state_cls.
+
+        Args:
+            target_state_cls: The target state class being fetched.
+            subclasses: Whether to include subclasses of the target state.
+            required_state_classes: Recursive argument tracking state classes that have already been seen.
+
+        Returns:
+            The set of state classes required to fetch the target state.
+        """
+        if required_state_classes is None:
+            required_state_classes = set()
+        # Get the substates if requested.
+        if subclasses:
+            for substate in target_state_cls.get_substates():
+                self._get_required_state_classes(
+                    substate,
+                    subclasses=True,
+                    required_state_classes=required_state_classes,
+                )
+        if target_state_cls in required_state_classes:
+            return required_state_classes
+        required_state_classes.add(target_state_cls)
+
+        # Get dependent substates.
+        for pd_substates in target_state_cls._get_potentially_dirty_states():
+            self._get_required_state_classes(
+                pd_substates,
+                subclasses=False,
+                required_state_classes=required_state_classes,
+            )
+
+        # Get the parent state if it exists.
+        if parent_state := target_state_cls.get_parent_state():
+            self._get_required_state_classes(
+                parent_state,
+                subclasses=False,
+                required_state_classes=required_state_classes,
+            )
+        return required_state_classes
+
+    def _get_populated_states(
+        self,
+        target_state: BaseState,
+        populated_states: dict[str, BaseState] | None = None,
+    ) -> dict[str, BaseState]:
+        """Recursively determine which states from target_state are already fetched.
+
+        Args:
+            target_state: The state to check for populated states.
+            populated_states: Recursive argument tracking states seen in previous calls.
+
+        Returns:
+            A dictionary of state full name to state instance.
+        """
+        if populated_states is None:
+            populated_states = {}
+        if target_state.get_full_name() in populated_states:
+            return populated_states
+        populated_states[target_state.get_full_name()] = target_state
+        for substate in target_state.substates.values():
+            self._get_populated_states(substate, populated_states=populated_states)
+        if target_state.parent_state is not None:
+            self._get_populated_states(
+                target_state.parent_state, populated_states=populated_states
+            )
+        return populated_states
+
+    @override
+    async def get_state(
+        self,
+        token: str,
+        top_level: bool = True,
+        for_state_instance: BaseState | None = None,
+    ) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+            top_level: If true, return an instance of the top-level state (self.state).
+            for_state_instance: If provided, attach the requested states to this existing state tree.
+
+        Returns:
+            The state for the token.
+
+        Raises:
+            RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
+                requested state was not fetched.
+        """
+        # Split the actual token from the fully qualified substate name.
+        token, state_path = _split_substate_key(token)
+        if state_path:
+            # Get the State class associated with the given path.
+            state_cls = self.state.get_class_substate(state_path)
+        else:
+            raise RuntimeError(
+                f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
+            )
+
+        # Determine which states we already have.
+        flat_state_tree: dict[str, BaseState] = (
+            self._get_populated_states(for_state_instance) if for_state_instance else {}
+        )
+
+        # Determine which states from the tree need to be fetched.
+        required_state_classes = sorted(
+            self._get_required_state_classes(state_cls, subclasses=True)
+            - {type(s) for s in flat_state_tree.values()},
+            key=lambda x: x.get_full_name(),
+        )
+
+        redis_pipeline = self.redis.pipeline()
+        for state_cls in required_state_classes:
+            redis_pipeline.get(_substate_key(token, state_cls))
+
+        for state_cls, redis_state in zip(
+            required_state_classes,
+            await redis_pipeline.execute(),
+            strict=False,
+        ):
+            state = None
+
+            if redis_state is not None:
+                # Deserialize the substate.
+                with contextlib.suppress(StateSchemaMismatchError):
+                    state = BaseState._deserialize(data=redis_state)
+            if state is None:
+                # Key didn't exist or schema mismatch so create a new instance for this token.
+                state = state_cls(
+                    init_substates=False,
+                    _reflex_internal_init=True,
+                )
+            flat_state_tree[state.get_full_name()] = state
+            if state.get_parent_state() is not None:
+                parent_state_name, _dot, state_name = state.get_full_name().rpartition(
+                    "."
+                )
+                parent_state = flat_state_tree.get(parent_state_name)
+                if parent_state is None:
+                    raise RuntimeError(
+                        f"Parent state for {state.get_full_name()} was not found "
+                        "in the state tree, but should have already been fetched. "
+                        "This is a bug",
+                    )
+                parent_state.substates[state_name] = state
+                state.parent_state = parent_state
+
+        # To retain compatibility with previous implementation, by default, we return
+        # the top-level state which should always be fetched or already cached.
+        if top_level:
+            return flat_state_tree[self.state.get_full_name()]
+        return flat_state_tree[state_cls.get_full_name()]
+
+    @override
+    async def set_state(
+        self,
+        token: str,
+        state: BaseState,
+        lock_id: bytes | None = None,
+    ):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+            lock_id: If provided, the lock_key must be set to this value to set the state.
+
+        Raises:
+            LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
+            RuntimeError: If the state instance doesn't match the state name in the token.
+        """
+        # Check that we're holding the lock.
+        if (
+            lock_id is not None
+            and await self.redis.get(self._lock_key(token)) != lock_id
+        ):
+            raise LockExpiredError(
+                f"Lock expired for token {token} while processing. Consider increasing "
+                f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
+                "or use `@rx.event(background=True)` decorator for long-running tasks."
+            )
+        elif lock_id is not None:
+            time_taken = self.lock_expiration / 1000 - (
+                await self.redis.ttl(self._lock_key(token))
+            )
+            if time_taken > self.lock_warning_threshold / 1000:
+                console.warn(
+                    f"Lock for token {token} was held too long {time_taken=}s, "
+                    f"use `@rx.event(background=True)` decorator for long-running tasks.",
+                    dedupe=True,
+                )
+
+        client_token, substate_name = _split_substate_key(token)
+        # If the substate name on the token doesn't match the instance name, it cannot have a parent.
+        if state.parent_state is not None and state.get_full_name() != substate_name:
+            raise RuntimeError(
+                f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
+            )
+
+        # Recursively set_state on all known substates.
+        tasks = [
+            asyncio.create_task(
+                self.set_state(
+                    _substate_key(client_token, substate),
+                    substate,
+                    lock_id,
+                )
+            )
+            for substate in state.substates.values()
+        ]
+        # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
+        if state._get_was_touched():
+            pickle_state = state._serialize()
+            if pickle_state:
+                await self.redis.set(
+                    _substate_key(client_token, state),
+                    pickle_state,
+                    ex=self.token_expiration,
+                )
+
+        # Wait for substates to be persisted.
+        for t in tasks:
+            await t
+
+    @override
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        async with self._lock(token) as lock_id:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state, lock_id)
+
+    @staticmethod
+    def _lock_key(token: str) -> bytes:
+        """Get the redis key for a token's lock.
+
+        Args:
+            token: The token to get the lock key for.
+
+        Returns:
+            The redis lock key for the token.
+        """
+        # All substates share the same lock domain, so ignore any substate path suffix.
+        client_token = _split_substate_key(token)[0]
+        return f"{client_token}_lock".encode()
+
+    async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
+        """Try to get a redis lock for a token.
+
+        Args:
+            lock_key: The redis key for the lock.
+            lock_id: The ID of the lock.
+
+        Returns:
+            True if the lock was obtained.
+        """
+        return await self.redis.set(
+            lock_key,
+            lock_id,
+            px=self.lock_expiration,
+            nx=True,  # only set if it doesn't exist
+        )
+
+    async def _get_pubsub_message(
+        self, pubsub: PubSub, timeout: float | None = None
+    ) -> None:
+        """Get lock release events from the pubsub.
+
+        Args:
+            pubsub: The pubsub to get a message from.
+            timeout: Remaining time to wait for a message.
+
+        Returns:
+            The message.
+        """
+        if timeout is None:
+            timeout = self.lock_expiration / 1000.0
+
+        started = time.time()
+        message = await pubsub.get_message(
+            ignore_subscribe_messages=True,
+            timeout=timeout,
+        )
+        if (
+            message is None
+            or message["data"] not in self._redis_keyspace_lock_release_events
+        ):
+            remaining = timeout - (time.time() - started)
+            if remaining <= 0:
+                return
+            await self._get_pubsub_message(pubsub, timeout=remaining)
+
+    async def _enable_keyspace_notifications(self):
+        """Enable keyspace notifications for the redis server.
+
+        Raises:
+            ResponseError: when the keyspace config cannot be set.
+        """
+        if self._redis_notify_keyspace_events_enabled:
+            return
+        # Find out which logical database index is being used.
+        self._redis_db = self.redis.get_connection_kwargs().get("db", self._redis_db)
+
+        try:
+            await self.redis.config_set(
+                "notify-keyspace-events",
+                self._redis_notify_keyspace_events,
+            )
+        except ResponseError:
+            # Some redis servers only allow out-of-band configuration, so ignore errors here.
+            if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
+                raise
+        self._redis_notify_keyspace_events_enabled = True
+
+    async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
+        """Wait for a redis lock to be released via pubsub.
+
+        Coroutine will not return until the lock is obtained.
+
+        Args:
+            lock_key: The redis key for the lock.
+            lock_id: The ID of the lock.
+        """
+        # Enable keyspace notifications for the lock key, so we know when it is available.
+        await self._enable_keyspace_notifications()
+        lock_key_channel = f"__keyspace@{self._redis_db}__:{lock_key.decode()}"
+        async with self.redis.pubsub() as pubsub:
+            await pubsub.psubscribe(lock_key_channel)
+            # wait for the lock to be released
+            while True:
+                # fast path
+                if await self._try_get_lock(lock_key, lock_id):
+                    return
+                # wait for lock events
+                await self._get_pubsub_message(pubsub)
+
+    @contextlib.asynccontextmanager
+    async def _lock(self, token: str):
+        """Obtain a redis lock for a token.
+
+        Args:
+            token: The token to obtain a lock for.
+
+        Yields:
+            The ID of the lock (to be passed to set_state).
+
+        Raises:
+            LockExpiredError: If the lock has expired while processing the event.
+        """
+        lock_key = self._lock_key(token)
+        lock_id = uuid.uuid4().hex.encode()
+
+        if not await self._try_get_lock(lock_key, lock_id):
+            # Missed the fast-path to get lock, subscribe for lock delete/expire events
+            await self._wait_lock(lock_key, lock_id)
+        state_is_locked = True
+
+        try:
+            yield lock_id
+        except LockExpiredError:
+            state_is_locked = False
+            raise
+        finally:
+            if state_is_locked:
+                # only delete our lock
+                await self.redis.delete(lock_key)
+
+    async def close(self):
+        """Explicitly close the redis connection and connection_pool.
+
+        It is necessary in testing scenarios to close between asyncio test cases
+        to avoid having lingering redis connections associated with event loops
+        that will be closed (each test case uses its own event loop).
+
+        Note: Connections will be automatically reopened when needed.
+        """
+        await self.redis.aclose(close_connection_pool=True)
+
+
+def get_state_manager() -> StateManager:
+    """Get the state manager for the app that is currently running.
+
+    Returns:
+        The state manager.
+    """
+    return prerequisites.get_and_validate_app().app.state_manager

+ 726 - 2
reflex/istate/proxy.py

@@ -1,8 +1,309 @@
 """A module to hold state proxy classes."""
 
-from typing import Any
+from __future__ import annotations
 
-from reflex.state import StateProxy
+import asyncio
+import copy
+import dataclasses
+import functools
+import inspect
+import json
+from collections.abc import Callable, Sequence
+from types import MethodType
+from typing import TYPE_CHECKING, Any, SupportsIndex
+
+import pydantic
+import wrapt
+from pydantic import BaseModel as BaseModelV2
+from pydantic.v1 import BaseModel as BaseModelV1
+from sqlalchemy.orm import DeclarativeBase
+
+from reflex.base import Base
+from reflex.utils import prerequisites
+from reflex.utils.exceptions import ImmutableStateError
+from reflex.utils.serializers import serializer
+from reflex.vars.base import Var
+
+if TYPE_CHECKING:
+    from reflex.state import BaseState, StateUpdate
+
+
+class StateProxy(wrapt.ObjectProxy):
+    """Proxy of a state instance to control mutability of vars for a background task.
+
+    Since a background task runs against a state instance without holding the
+    state_manager lock for the token, the reference may become stale if the same
+    state is modified by another event handler.
+
+    The proxy object ensures that writes to the state are blocked unless
+    explicitly entering a context which refreshes the state from state_manager
+    and holds the lock for the token until exiting the context. After exiting
+    the context, a StateUpdate may be emitted to the frontend to notify the
+    client of the state change.
+
+    A background task will be passed the `StateProxy` as `self`, so mutability
+    can be safely performed inside an `async with self` block.
+
+        class State(rx.State):
+            counter: int = 0
+
+            @rx.event(background=True)
+            async def bg_increment(self):
+                await asyncio.sleep(1)
+                async with self:
+                    self.counter += 1
+    """
+
+    def __init__(
+        self,
+        state_instance: BaseState,
+        parent_state_proxy: StateProxy | None = None,
+    ):
+        """Create a proxy for a state instance.
+
+        If `get_state` is used on a StateProxy, the resulting state will be
+        linked to the given state via parent_state_proxy. The first state in the
+        chain is the state that initiated the background task.
+
+        Args:
+            state_instance: The state instance to proxy.
+            parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
+        """
+        super().__init__(state_instance)
+        # compile is not relevant to backend logic
+        self._self_app = prerequisites.get_and_validate_app().app
+        self._self_substate_path = tuple(state_instance.get_full_name().split("."))
+        self._self_actx = None
+        self._self_mutable = False
+        self._self_actx_lock = asyncio.Lock()
+        self._self_actx_lock_holder = None
+        self._self_parent_state_proxy = parent_state_proxy
+
+    def _is_mutable(self) -> bool:
+        """Check if the state is mutable.
+
+        Returns:
+            Whether the state is mutable.
+        """
+        if self._self_parent_state_proxy is not None:
+            return self._self_parent_state_proxy._is_mutable() or self._self_mutable
+        return self._self_mutable
+
+    async def __aenter__(self) -> StateProxy:
+        """Enter the async context manager protocol.
+
+        Sets mutability to True and enters the `App.modify_state` async context,
+        which refreshes the state from state_manager and holds the lock for the
+        given state token until exiting the context.
+
+        Background tasks should avoid blocking calls while inside the context.
+
+        Returns:
+            This StateProxy instance in mutable mode.
+
+        Raises:
+            ImmutableStateError: If the state is already mutable.
+        """
+        if self._self_parent_state_proxy is not None:
+            from reflex.state import State
+
+            parent_state = (
+                await self._self_parent_state_proxy.__aenter__()
+            ).__wrapped__
+            super().__setattr__(
+                "__wrapped__",
+                await parent_state.get_state(
+                    State.get_class_substate(self._self_substate_path)
+                ),
+            )
+            return self
+        current_task = asyncio.current_task()
+        if (
+            self._self_actx_lock.locked()
+            and current_task == self._self_actx_lock_holder
+        ):
+            raise ImmutableStateError(
+                "The state is already mutable. Do not nest `async with self` blocks."
+            )
+
+        from reflex.state import _substate_key
+
+        await self._self_actx_lock.acquire()
+        self._self_actx_lock_holder = current_task
+        self._self_actx = self._self_app.modify_state(
+            token=_substate_key(
+                self.__wrapped__.router.session.client_token,
+                self._self_substate_path,
+            )
+        )
+        mutable_state = await self._self_actx.__aenter__()
+        super().__setattr__(
+            "__wrapped__", mutable_state.get_substate(self._self_substate_path)
+        )
+        self._self_mutable = True
+        return self
+
+    async def __aexit__(self, *exc_info: Any) -> None:
+        """Exit the async context manager protocol.
+
+        Sets proxy mutability to False and persists any state changes.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        if self._self_parent_state_proxy is not None:
+            await self._self_parent_state_proxy.__aexit__(*exc_info)
+            return
+        if self._self_actx is None:
+            return
+        self._self_mutable = False
+        try:
+            await self._self_actx.__aexit__(*exc_info)
+        finally:
+            self._self_actx_lock_holder = None
+            self._self_actx_lock.release()
+        self._self_actx = None
+
+    def __enter__(self):
+        """Enter the regular context manager protocol.
+
+        This is not supported for background tasks, and exists only to raise a more useful exception
+        when the StateProxy is used incorrectly.
+
+        Raises:
+            TypeError: always, because only async contextmanager protocol is supported.
+        """
+        raise TypeError("Background task must use `async with self` to modify state.")
+
+    def __exit__(self, *exc_info: Any) -> None:
+        """Exit the regular context manager protocol.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        pass
+
+    def __getattr__(self, name: str) -> Any:
+        """Get the attribute from the underlying state instance.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The value of the attribute.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if name in ["substates", "parent_state"] and not self._is_mutable():
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+
+        value = super().__getattr__(name)
+        if not name.startswith("_self_") and isinstance(value, MutableProxy):
+            # ensure mutations to these containers are blocked unless proxy is _mutable
+            return ImmutableMutableProxy(
+                wrapped=value.__wrapped__,
+                state=self,
+                field_name=value._self_field_name,
+            )
+        if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
+            # Rebind event handler to the proxy instance
+            value = functools.partial(
+                value.func,
+                self,
+                *value.args[1:],
+                **value.keywords,
+            )
+        if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
+            # Rebind methods to the proxy instance
+            value = type(value)(value.__func__, self)
+        return value
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set the attribute on the underlying state instance.
+
+        If the attribute is internal, set it on the proxy instance instead.
+
+        Args:
+            name: The name of the attribute.
+            value: The value of the attribute.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if (
+            name.startswith("_self_")  # wrapper attribute
+            or self._is_mutable()  # lock held
+            # non-persisted state attribute
+            or name in self.__wrapped__.get_skip_vars()
+        ):
+            super().__setattr__(name, value)
+            return
+
+        raise ImmutableStateError(
+            "Background task StateProxy is immutable outside of a context "
+            "manager. Use `async with self` to modify state."
+        )
+
+    def get_substate(self, path: Sequence[str]) -> BaseState:
+        """Only allow substate access with lock held.
+
+        Args:
+            path: The path to the substate.
+
+        Returns:
+            The substate.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if not self._is_mutable():
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        return self.__wrapped__.get_substate(path)
+
+    async def get_state(self, state_cls: type[BaseState]) -> BaseState:
+        """Get an instance of the state associated with this token.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The state.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if not self._is_mutable():
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        return type(self)(
+            await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
+        )
+
+    async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
+        """Temporarily allow mutability to access parent_state.
+
+        Args:
+            *args: The args to pass to the underlying state instance.
+            **kwargs: The kwargs to pass to the underlying state instance.
+
+        Returns:
+            The state update.
+        """
+        original_mutable = self._self_mutable
+        self._self_mutable = True
+        try:
+            return await self.__wrapped__._as_state_update(*args, **kwargs)
+        finally:
+            self._self_mutable = original_mutable
 
 
 class ReadOnlyStateProxy(StateProxy):
@@ -31,3 +332,426 @@ class ReadOnlyStateProxy(StateProxy):
             NotImplementedError: Always raised when trying to mark the proxied state as dirty.
         """
         raise NotImplementedError("This is a read-only state proxy.")
+
+
+class MutableProxy(wrapt.ObjectProxy):
+    """A proxy for a mutable object that tracks changes."""
+
+    # Hint for finding the base class of the proxy.
+    __base_proxy__ = "MutableProxy"
+
+    # Methods on wrapped objects which should mark the state as dirty.
+    __mark_dirty_attrs__ = {
+        "add",
+        "append",
+        "clear",
+        "difference_update",
+        "discard",
+        "extend",
+        "insert",
+        "intersection_update",
+        "pop",
+        "popitem",
+        "remove",
+        "reverse",
+        "setdefault",
+        "sort",
+        "symmetric_difference_update",
+        "update",
+    }
+
+    # Methods on wrapped objects might return mutable objects that should be tracked.
+    __wrap_mutable_attrs__ = {
+        "get",
+        "setdefault",
+    }
+
+    # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
+    __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
+        pydantic.BaseModel.__dict__
+    )
+
+    # These types will be wrapped in MutableProxy
+    __mutable_types__ = (
+        list,
+        dict,
+        set,
+        Base,
+        DeclarativeBase,
+        BaseModelV2,
+        BaseModelV1,
+    )
+
+    # Dynamically generated classes for tracking dataclass mutations.
+    __dataclass_proxies__: dict[type, type] = {}
+
+    def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
+        """Create a proxy instance for a mutable object that tracks changes.
+
+        Args:
+            wrapped: The object to proxy.
+            *args: Other args passed to MutableProxy (ignored).
+            **kwargs: Other kwargs passed to MutableProxy (ignored).
+
+        Returns:
+            The proxy instance.
+        """
+        if dataclasses.is_dataclass(wrapped):
+            wrapped_cls = type(wrapped)
+            wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
+            # Find the associated class
+            if wrapper_cls_name not in cls.__dataclass_proxies__:
+                # Create a new class that has the __dataclass_fields__ defined
+                cls.__dataclass_proxies__[wrapper_cls_name] = type(
+                    wrapper_cls_name,
+                    (cls,),
+                    {
+                        dataclasses._FIELDS: getattr(  # pyright: ignore [reportAttributeAccessIssue]
+                            wrapped_cls,
+                            dataclasses._FIELDS,  # pyright: ignore [reportAttributeAccessIssue]
+                        ),
+                    },
+                )
+            cls = cls.__dataclass_proxies__[wrapper_cls_name]
+        return super().__new__(cls)
+
+    def __init__(self, wrapped: Any, state: BaseState, field_name: str):
+        """Create a proxy for a mutable object that tracks changes.
+
+        Args:
+            wrapped: The object to proxy.
+            state: The state to mark dirty when the object is changed.
+            field_name: The name of the field on the state associated with the
+                wrapped object.
+        """
+        super().__init__(wrapped)
+        self._self_state = state
+        self._self_field_name = field_name
+
+    def __repr__(self) -> str:
+        """Get the representation of the wrapped object.
+
+        Returns:
+            The representation of the wrapped object.
+        """
+        return f"{type(self).__name__}({self.__wrapped__})"
+
+    def _mark_dirty(
+        self,
+        wrapped: Callable | None = None,
+        instance: BaseState | None = None,
+        args: tuple = (),
+        kwargs: dict | None = None,
+    ) -> Any:
+        """Mark the state as dirty, then call a wrapped function.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Returns:
+            The result of the wrapped function.
+        """
+        self._self_state.dirty_vars.add(self._self_field_name)
+        self._self_state._mark_dirty()
+        if wrapped is not None:
+            return wrapped(*args, **(kwargs or {}))
+
+    @classmethod
+    def _is_mutable_type(cls, value: Any) -> bool:
+        """Check if a value is of a mutable type and should be wrapped.
+
+        Args:
+            value: The value to check.
+
+        Returns:
+            Whether the value is of a mutable type.
+        """
+        return isinstance(value, cls.__mutable_types__) or (
+            dataclasses.is_dataclass(value) and not isinstance(value, Var)
+        )
+
+    @staticmethod
+    def _is_called_from_dataclasses_internal() -> bool:
+        """Check if the current function is called from dataclasses helper.
+
+        Returns:
+            Whether the current function is called from dataclasses internal code.
+        """
+        # Walk up the stack a bit to see if we are called from dataclasses
+        # internal code, for example `asdict` or `astuple`.
+        frame = inspect.currentframe()
+        for _ in range(5):
+            # Why not `inspect.stack()` -- this is much faster!
+            if not (frame := frame and frame.f_back):
+                break
+            if inspect.getfile(frame) == dataclasses.__file__:
+                return True
+        return False
+
+    def _wrap_recursive(self, value: Any) -> Any:
+        """Wrap a value recursively if it is mutable.
+
+        Args:
+            value: The value to wrap.
+
+        Returns:
+            The wrapped value.
+        """
+        # When called from dataclasses internal code, return the unwrapped value
+        if self._is_called_from_dataclasses_internal():
+            return value
+        # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
+        if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
+            base_cls = globals()[self.__base_proxy__]
+            return base_cls(
+                wrapped=value,
+                state=self._self_state,
+                field_name=self._self_field_name,
+            )
+        return value
+
+    def _wrap_recursive_decorator(
+        self, wrapped: Callable, instance: BaseState, args: list, kwargs: dict
+    ) -> Any:
+        """Wrap a function that returns a possibly mutable value.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Returns:
+            The result of the wrapped function (possibly wrapped in a MutableProxy).
+        """
+        return self._wrap_recursive(wrapped(*args, **kwargs))
+
+    def __getattr__(self, __name: str) -> Any:
+        """Get the attribute on the proxied object and return a proxy if mutable.
+
+        Args:
+            __name: The name of the attribute.
+
+        Returns:
+            The attribute value.
+        """
+        value = super().__getattr__(__name)
+
+        if callable(value):
+            if __name in self.__mark_dirty_attrs__:
+                # Wrap special callables, like "append", which should mark state dirty.
+                value = wrapt.FunctionWrapper(value, self._mark_dirty)
+
+            if __name in self.__wrap_mutable_attrs__:
+                # Wrap methods that may return mutable objects tied to the state.
+                value = wrapt.FunctionWrapper(
+                    value,
+                    self._wrap_recursive_decorator,
+                )
+
+            if (
+                isinstance(self.__wrapped__, Base)
+                and __name not in self.__never_wrap_base_attrs__
+                and hasattr(value, "__func__")
+            ):
+                # Wrap methods called on Base subclasses, which might do _anything_
+                return wrapt.FunctionWrapper(
+                    functools.partial(value.__func__, self),  # pyright: ignore [reportFunctionMemberAccess]
+                    self._wrap_recursive_decorator,
+                )
+
+        if self._is_mutable_type(value) and __name not in (
+            "__wrapped__",
+            "_self_state",
+            "__dict__",
+        ):
+            # Recursively wrap mutable attribute values retrieved through this proxy.
+            return self._wrap_recursive(value)
+
+        return value
+
+    def __getitem__(self, key: Any) -> Any:
+        """Get the item on the proxied object and return a proxy if mutable.
+
+        Args:
+            key: The key of the item.
+
+        Returns:
+            The item value.
+        """
+        value = super().__getitem__(key)
+        if isinstance(key, slice) and isinstance(value, list):
+            return [self._wrap_recursive(item) for item in value]
+        # Recursively wrap mutable items retrieved through this proxy.
+        return self._wrap_recursive(value)
+
+    def __iter__(self) -> Any:
+        """Iterate over the proxied object and return a proxy if mutable.
+
+        Yields:
+            Each item value (possibly wrapped in MutableProxy).
+        """
+        for value in super().__iter__():
+            # Recursively wrap mutable items retrieved through this proxy.
+            yield self._wrap_recursive(value)
+
+    def __delattr__(self, name: str):
+        """Delete the attribute on the proxied object and mark state dirty.
+
+        Args:
+            name: The name of the attribute.
+        """
+        self._mark_dirty(super().__delattr__, args=(name,))
+
+    def __delitem__(self, key: str):
+        """Delete the item on the proxied object and mark state dirty.
+
+        Args:
+            key: The key of the item.
+        """
+        self._mark_dirty(super().__delitem__, args=(key,))
+
+    def __setitem__(self, key: str, value: Any):
+        """Set the item on the proxied object and mark state dirty.
+
+        Args:
+            key: The key of the item.
+            value: The value of the item.
+        """
+        self._mark_dirty(super().__setitem__, args=(key, value))
+
+    def __setattr__(self, name: str, value: Any):
+        """Set the attribute on the proxied object and mark state dirty.
+
+        If the attribute starts with "_self_", then the state is NOT marked
+        dirty as these are internal proxy attributes.
+
+        Args:
+            name: The name of the attribute.
+            value: The value of the attribute.
+        """
+        if name.startswith("_self_"):
+            # Special case attributes of the proxy itself, not applied to the wrapped object.
+            super().__setattr__(name, value)
+            return
+        self._mark_dirty(super().__setattr__, args=(name, value))
+
+    def __copy__(self) -> Any:
+        """Return a copy of the proxy.
+
+        Returns:
+            A copy of the wrapped object, unconnected to the proxy.
+        """
+        return copy.copy(self.__wrapped__)
+
+    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Any:
+        """Return a deepcopy of the proxy.
+
+        Args:
+            memo: The memo dict to use for the deepcopy.
+
+        Returns:
+            A deepcopy of the wrapped object, unconnected to the proxy.
+        """
+        return copy.deepcopy(self.__wrapped__, memo=memo)
+
+    def __reduce_ex__(self, protocol_version: SupportsIndex):
+        """Get the state for redis serialization.
+
+        This method is called by cloudpickle to serialize the object.
+
+        It explicitly serializes the wrapped object, stripping off the mutable proxy.
+
+        Args:
+            protocol_version: The protocol version.
+
+        Returns:
+            Tuple of (wrapped class, empty args, class __getstate__)
+        """
+        return self.__wrapped__.__reduce_ex__(protocol_version)
+
+
+@serializer
+def serialize_mutable_proxy(mp: MutableProxy):
+    """Return the wrapped value of a MutableProxy.
+
+    Args:
+        mp: The MutableProxy to serialize.
+
+    Returns:
+        The wrapped object.
+    """
+    return mp.__wrapped__
+
+
+_orig_json_encoder_default = json.JSONEncoder.default
+
+
+def _json_encoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
+    """Wrap JSONEncoder.default to handle MutableProxy objects.
+
+    Args:
+        self: the JSONEncoder instance.
+        o: the object to serialize.
+
+    Returns:
+        A JSON-able object.
+    """
+    try:
+        return o.__wrapped__
+    except AttributeError:
+        pass
+    return _orig_json_encoder_default(self, o)
+
+
+json.JSONEncoder.default = _json_encoder_default_wrapper
+
+
+class ImmutableMutableProxy(MutableProxy):
+    """A proxy for a mutable object that tracks changes.
+
+    This wrapper comes from StateProxy, and will raise an exception if an attempt is made
+    to modify the wrapped object when the StateProxy is immutable.
+    """
+
+    # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
+    __base_proxy__ = "ImmutableMutableProxy"
+
+    def _mark_dirty(
+        self,
+        wrapped: Callable | None = None,
+        instance: BaseState | None = None,
+        args: tuple = (),
+        kwargs: dict | None = None,
+    ) -> Any:
+        """Raise an exception when an attempt is made to modify the object.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Returns:
+            The result of the wrapped function.
+
+        Raises:
+            ImmutableStateError: if the StateProxy is not mutable.
+        """
+        if not self._self_state._is_mutable():
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        return super()._mark_dirty(
+            wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
+        )

+ 74 - 1624
reflex/state.py

@@ -9,24 +9,19 @@ import copy
 import dataclasses
 import functools
 import inspect
-import json
 import pickle
 import sys
-import time
 import typing
-import uuid
 import warnings
-from abc import ABC, abstractmethod
+from abc import ABC
 from collections.abc import AsyncIterator, Callable, Sequence
 from hashlib import md5
-from pathlib import Path
-from types import FunctionType, MethodType
+from types import FunctionType
 from typing import (
     TYPE_CHECKING,
     Any,
     BinaryIO,
     ClassVar,
-    SupportsIndex,
     TypeVar,
     cast,
     get_args,
@@ -34,22 +29,16 @@ from typing import (
 )
 
 import pydantic.v1 as pydantic
-import wrapt
 from pydantic import BaseModel as BaseModelV2
 from pydantic.v1 import BaseModel as BaseModelV1
-from pydantic.v1 import validator
 from pydantic.v1.fields import ModelField
-from redis.asyncio import Redis
-from redis.asyncio.client import PubSub
-from redis.exceptions import ResponseError
 from rich.markup import escape
-from sqlalchemy.orm import DeclarativeBase
 from typing_extensions import Self
 
 import reflex.istate.dynamic
 from reflex import constants, event
 from reflex.base import Base
-from reflex.config import PerformanceMode, environment, get_config
+from reflex.config import PerformanceMode, environment
 from reflex.event import (
     BACKGROUND_TASK_MARKER,
     Event,
@@ -58,19 +47,17 @@ from reflex.event import (
     fix_events,
 )
 from reflex.istate.data import RouterData
+from reflex.istate.proxy import ImmutableMutableProxy as ImmutableMutableProxy
+from reflex.istate.proxy import MutableProxy, StateProxy
 from reflex.istate.storage import ClientStorageBase
 from reflex.model import Model
-from reflex.utils import console, format, path_ops, prerequisites, types
+from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import (
     ComputedVarShadowsBaseVarsError,
     ComputedVarShadowsStateVarError,
     DynamicComponentInvalidSignatureError,
     DynamicRouteArgShadowsStateVarError,
     EventHandlerShadowsBuiltInStateMethodError,
-    ImmutableStateError,
-    InvalidLockWarningThresholdError,
-    InvalidStateManagerModeError,
-    LockExpiredError,
     ReflexRuntimeError,
     SetUndefinedStateVarError,
     StateMismatchError,
@@ -79,13 +66,12 @@ from reflex.utils.exceptions import (
     StateTooLargeError,
     UnretrievableVarValueError,
 )
+from reflex.utils.exceptions import ImmutableStateError as ImmutableStateError
 from reflex.utils.exec import is_testing_env
-from reflex.utils.serializers import serializer
 from reflex.utils.types import (
     _isinstance,
     get_origin,
     is_union,
-    override,
     true_type_for_pydantic_field,
     value_inside_optional,
 )
@@ -2284,6 +2270,35 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return state
 
 
+def _serialize_type(type_: Any) -> str:
+    """Serialize a type.
+
+    Args:
+        type_: The type to serialize.
+
+    Returns:
+        The serialized type.
+    """
+    if not inspect.isclass(type_):
+        return f"{type_}"
+    return f"{type_.__module__}.{type_.__qualname__}"
+
+
+def is_serializable(value: Any) -> bool:
+    """Check if a value is serializable.
+
+    Args:
+        value: The value to check.
+
+    Returns:
+        Whether the value is serializable.
+    """
+    try:
+        return bool(pickle.dumps(value))
+    except Exception:
+        return False
+
+
 T_STATE = TypeVar("T_STATE", bound=BaseState)
 
 
@@ -2523,278 +2538,6 @@ class ComponentState(State, mixin=True):
         return component
 
 
-class StateProxy(wrapt.ObjectProxy):
-    """Proxy of a state instance to control mutability of vars for a background task.
-
-    Since a background task runs against a state instance without holding the
-    state_manager lock for the token, the reference may become stale if the same
-    state is modified by another event handler.
-
-    The proxy object ensures that writes to the state are blocked unless
-    explicitly entering a context which refreshes the state from state_manager
-    and holds the lock for the token until exiting the context. After exiting
-    the context, a StateUpdate may be emitted to the frontend to notify the
-    client of the state change.
-
-    A background task will be passed the `StateProxy` as `self`, so mutability
-    can be safely performed inside an `async with self` block.
-
-        class State(rx.State):
-            counter: int = 0
-
-            @rx.event(background=True)
-            async def bg_increment(self):
-                await asyncio.sleep(1)
-                async with self:
-                    self.counter += 1
-    """
-
-    def __init__(
-        self,
-        state_instance: BaseState,
-        parent_state_proxy: StateProxy | None = None,
-    ):
-        """Create a proxy for a state instance.
-
-        If `get_state` is used on a StateProxy, the resulting state will be
-        linked to the given state via parent_state_proxy. The first state in the
-        chain is the state that initiated the background task.
-
-        Args:
-            state_instance: The state instance to proxy.
-            parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
-        """
-        super().__init__(state_instance)
-        # compile is not relevant to backend logic
-        self._self_app = prerequisites.get_and_validate_app().app
-        self._self_substate_path = tuple(state_instance.get_full_name().split("."))
-        self._self_actx = None
-        self._self_mutable = False
-        self._self_actx_lock = asyncio.Lock()
-        self._self_actx_lock_holder = None
-        self._self_parent_state_proxy = parent_state_proxy
-
-    def _is_mutable(self) -> bool:
-        """Check if the state is mutable.
-
-        Returns:
-            Whether the state is mutable.
-        """
-        if self._self_parent_state_proxy is not None:
-            return self._self_parent_state_proxy._is_mutable() or self._self_mutable
-        return self._self_mutable
-
-    async def __aenter__(self) -> StateProxy:
-        """Enter the async context manager protocol.
-
-        Sets mutability to True and enters the `App.modify_state` async context,
-        which refreshes the state from state_manager and holds the lock for the
-        given state token until exiting the context.
-
-        Background tasks should avoid blocking calls while inside the context.
-
-        Returns:
-            This StateProxy instance in mutable mode.
-
-        Raises:
-            ImmutableStateError: If the state is already mutable.
-        """
-        if self._self_parent_state_proxy is not None:
-            parent_state = (
-                await self._self_parent_state_proxy.__aenter__()
-            ).__wrapped__
-            super().__setattr__(
-                "__wrapped__",
-                await parent_state.get_state(
-                    State.get_class_substate(self._self_substate_path)
-                ),
-            )
-            return self
-        current_task = asyncio.current_task()
-        if (
-            self._self_actx_lock.locked()
-            and current_task == self._self_actx_lock_holder
-        ):
-            raise ImmutableStateError(
-                "The state is already mutable. Do not nest `async with self` blocks."
-            )
-        await self._self_actx_lock.acquire()
-        self._self_actx_lock_holder = current_task
-        self._self_actx = self._self_app.modify_state(
-            token=_substate_key(
-                self.__wrapped__.router.session.client_token,
-                self._self_substate_path,
-            )
-        )
-        mutable_state = await self._self_actx.__aenter__()
-        super().__setattr__(
-            "__wrapped__", mutable_state.get_substate(self._self_substate_path)
-        )
-        self._self_mutable = True
-        return self
-
-    async def __aexit__(self, *exc_info: Any) -> None:
-        """Exit the async context manager protocol.
-
-        Sets proxy mutability to False and persists any state changes.
-
-        Args:
-            exc_info: The exception info tuple.
-        """
-        if self._self_parent_state_proxy is not None:
-            await self._self_parent_state_proxy.__aexit__(*exc_info)
-            return
-        if self._self_actx is None:
-            return
-        self._self_mutable = False
-        try:
-            await self._self_actx.__aexit__(*exc_info)
-        finally:
-            self._self_actx_lock_holder = None
-            self._self_actx_lock.release()
-        self._self_actx = None
-
-    def __enter__(self):
-        """Enter the regular context manager protocol.
-
-        This is not supported for background tasks, and exists only to raise a more useful exception
-        when the StateProxy is used incorrectly.
-
-        Raises:
-            TypeError: always, because only async contextmanager protocol is supported.
-        """
-        raise TypeError("Background task must use `async with self` to modify state.")
-
-    def __exit__(self, *exc_info: Any) -> None:
-        """Exit the regular context manager protocol.
-
-        Args:
-            exc_info: The exception info tuple.
-        """
-        pass
-
-    def __getattr__(self, name: str) -> Any:
-        """Get the attribute from the underlying state instance.
-
-        Args:
-            name: The name of the attribute.
-
-        Returns:
-            The value of the attribute.
-
-        Raises:
-            ImmutableStateError: If the state is not in mutable mode.
-        """
-        if name in ["substates", "parent_state"] and not self._is_mutable():
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        value = super().__getattr__(name)
-        if not name.startswith("_self_") and isinstance(value, MutableProxy):
-            # ensure mutations to these containers are blocked unless proxy is _mutable
-            return ImmutableMutableProxy(
-                wrapped=value.__wrapped__,
-                state=self,
-                field_name=value._self_field_name,
-            )
-        if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
-            # Rebind event handler to the proxy instance
-            value = functools.partial(
-                value.func,
-                self,
-                *value.args[1:],
-                **value.keywords,
-            )
-        if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
-            # Rebind methods to the proxy instance
-            value = type(value)(value.__func__, self)
-        return value
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        """Set the attribute on the underlying state instance.
-
-        If the attribute is internal, set it on the proxy instance instead.
-
-        Args:
-            name: The name of the attribute.
-            value: The value of the attribute.
-
-        Raises:
-            ImmutableStateError: If the state is not in mutable mode.
-        """
-        if (
-            name.startswith("_self_")  # wrapper attribute
-            or self._is_mutable()  # lock held
-            # non-persisted state attribute
-            or name in self.__wrapped__.get_skip_vars()
-        ):
-            super().__setattr__(name, value)
-            return
-
-        raise ImmutableStateError(
-            "Background task StateProxy is immutable outside of a context "
-            "manager. Use `async with self` to modify state."
-        )
-
-    def get_substate(self, path: Sequence[str]) -> BaseState:
-        """Only allow substate access with lock held.
-
-        Args:
-            path: The path to the substate.
-
-        Returns:
-            The substate.
-
-        Raises:
-            ImmutableStateError: If the state is not in mutable mode.
-        """
-        if not self._is_mutable():
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        return self.__wrapped__.get_substate(path)
-
-    async def get_state(self, state_cls: type[BaseState]) -> BaseState:
-        """Get an instance of the state associated with this token.
-
-        Args:
-            state_cls: The class of the state.
-
-        Returns:
-            The state.
-
-        Raises:
-            ImmutableStateError: If the state is not in mutable mode.
-        """
-        if not self._is_mutable():
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        return type(self)(
-            await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
-        )
-
-    async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
-        """Temporarily allow mutability to access parent_state.
-
-        Args:
-            *args: The args to pass to the underlying state instance.
-            **kwargs: The kwargs to pass to the underlying state instance.
-
-        Returns:
-            The state update.
-        """
-        original_mutable = self._self_mutable
-        self._self_mutable = True
-        try:
-            return await self.__wrapped__._as_state_update(*args, **kwargs)
-        finally:
-            self._self_mutable = original_mutable
-
-
 @dataclasses.dataclass(
     frozen=True,
 )
@@ -2819,1347 +2562,54 @@ class StateUpdate:
         return format.json_dumps(self)
 
 
-class StateManager(Base, ABC):
-    """A class to manage many client states."""
-
-    # The state class to use.
-    state: type[BaseState]
-
-    @classmethod
-    def create(cls, state: type[BaseState]):
-        """Create a new state manager.
-
-        Args:
-            state: The state class to use.
-
-        Raises:
-            InvalidStateManagerModeError: If the state manager mode is invalid.
-
-        Returns:
-            The state manager (either disk, memory or redis).
-        """
-        config = get_config()
-        if prerequisites.parse_redis_url() is not None:
-            config.state_manager_mode = constants.StateManagerMode.REDIS
-        if config.state_manager_mode == constants.StateManagerMode.MEMORY:
-            return StateManagerMemory(state=state)
-        if config.state_manager_mode == constants.StateManagerMode.DISK:
-            return StateManagerDisk(state=state)
-        if config.state_manager_mode == constants.StateManagerMode.REDIS:
-            redis = prerequisites.get_redis()
-            if redis is not None:
-                # make sure expiration values are obtained only from the config object on creation
-                return StateManagerRedis(
-                    state=state,
-                    redis=redis,
-                    token_expiration=config.redis_token_expiration,
-                    lock_expiration=config.redis_lock_expiration,
-                    lock_warning_threshold=config.redis_lock_warning_threshold,
-                )
-        raise InvalidStateManagerModeError(
-            f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
-        )
-
-    @abstractmethod
-    async def get_state(self, token: str) -> BaseState:
-        """Get the state for a token.
-
-        Args:
-            token: The token to get the state for.
-
-        Returns:
-            The state for the token.
-        """
-        pass
-
-    @abstractmethod
-    async def set_state(self, token: str, state: BaseState):
-        """Set the state for a token.
-
-        Args:
-            token: The token to set the state for.
-            state: The state to set.
-        """
-        pass
-
-    @abstractmethod
-    @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
-        """Modify the state for a token while holding exclusive lock.
-
-        Args:
-            token: The token to modify the state for.
-
-        Yields:
-            The state for the token.
-        """
-        yield self.state()
-
-
-class StateManagerMemory(StateManager):
-    """A state manager that stores states in memory."""
-
-    # The mapping of client ids to states.
-    states: dict[str, BaseState] = {}
-
-    # The mutex ensures the dict of mutexes is updated exclusively
-    _state_manager_lock = asyncio.Lock()
-
-    # The dict of mutexes for each client
-    _states_locks: dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
-
-    class Config:  # pyright: ignore [reportIncompatibleVariableOverride]
-        """The Pydantic config."""
-
-        fields = {
-            "_states_locks": {"exclude": True},
-        }
-
-    @override
-    async def get_state(self, token: str) -> BaseState:
-        """Get the state for a token.
-
-        Args:
-            token: The token to get the state for.
-
-        Returns:
-            The state for the token.
-        """
-        # Memory state manager ignores the substate suffix and always returns the top-level state.
-        token = _split_substate_key(token)[0]
-        if token not in self.states:
-            self.states[token] = self.state(_reflex_internal_init=True)
-        return self.states[token]
-
-    @override
-    async def set_state(self, token: str, state: BaseState):
-        """Set the state for a token.
-
-        Args:
-            token: The token to set the state for.
-            state: The state to set.
-        """
-        pass
-
-    @override
-    @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
-        """Modify the state for a token while holding exclusive lock.
-
-        Args:
-            token: The token to modify the state for.
-
-        Yields:
-            The state for the token.
-        """
-        # Memory state manager ignores the substate suffix and always returns the top-level state.
-        token = _split_substate_key(token)[0]
-        if token not in self._states_locks:
-            async with self._state_manager_lock:
-                if token not in self._states_locks:
-                    self._states_locks[token] = asyncio.Lock()
-
-        async with self._states_locks[token]:
-            state = await self.get_state(token)
-            yield state
-            await self.set_state(token, state)
-
-
-def _default_token_expiration() -> int:
-    """Get the default token expiration time.
-
-    Returns:
-        The default token expiration time.
-    """
-    return get_config().redis_token_expiration
-
-
-def _serialize_type(type_: Any) -> str:
-    """Serialize a type.
+def code_uses_state_contexts(javascript_code: str) -> bool:
+    """Check if the rendered Javascript uses state contexts.
 
     Args:
-        type_: The type to serialize.
+        javascript_code: The Javascript code to check.
 
     Returns:
-        The serialized type.
+        True if the code attempts to access a member of StateContexts.
     """
-    if not inspect.isclass(type_):
-        return f"{type_}"
-    return f"{type_.__module__}.{type_.__qualname__}"
+    return bool("useContext(StateContexts" in javascript_code)
 
 
-def is_serializable(value: Any) -> bool:
-    """Check if a value is serializable.
+def reload_state_module(
+    module: str,
+    state: type[BaseState] = State,
+) -> None:
+    """Reset rx.State subclasses to avoid conflict when reloading.
 
     Args:
-        value: The value to check.
+        module: The module to reload.
+        state: Recursive argument for the state class to reload.
 
-    Returns:
-        Whether the value is serializable.
     """
-    try:
-        return bool(pickle.dumps(value))
-    except Exception:
-        return False
-
-
-def reset_disk_state_manager():
-    """Reset the disk state manager."""
-    states_directory = prerequisites.get_states_dir()
-    if states_directory.exists():
-        for path in states_directory.iterdir():
-            path.unlink()
+    # Clean out all potentially dirty states of reloaded modules.
+    for pd_state in tuple(state._potentially_dirty_states):
+        with contextlib.suppress(ValueError):
+            if (
+                state.get_root_state().get_class_substate(pd_state).__module__ == module
+                and module is not None
+            ):
+                state._potentially_dirty_states.remove(pd_state)
+    for subclass in tuple(state.class_subclasses):
+        reload_state_module(module=module, state=subclass)
+        if subclass.__module__ == module and module is not None:
+            all_base_state_classes.pop(subclass.get_full_name(), None)
+            state.class_subclasses.remove(subclass)
+            state._always_dirty_substates.discard(subclass.get_name())
+            state._var_dependencies = {}
+            state._init_var_dependency_dicts()
+    state.get_class_substate.cache_clear()
 
 
-class StateManagerDisk(StateManager):
-    """A state manager that stores states in memory."""
-
-    # The mapping of client ids to states.
-    states: dict[str, BaseState] = {}
-
-    # The mutex ensures the dict of mutexes is updated exclusively
-    _state_manager_lock = asyncio.Lock()
-
-    # The dict of mutexes for each client
-    _states_locks: dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
-
-    # The token expiration time (s).
-    token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
-
-    class Config:  # pyright: ignore [reportIncompatibleVariableOverride]
-        """The Pydantic config."""
-
-        fields = {
-            "_states_locks": {"exclude": True},
-        }
-        keep_untouched = (functools.cached_property,)
-
-    def __init__(self, state: type[BaseState]):
-        """Create a new state manager.
-
-        Args:
-            state: The state class to use.
-        """
-        super().__init__(state=state)
-
-        path_ops.mkdir(self.states_directory)
-
-        self._purge_expired_states()
-
-    @functools.cached_property
-    def states_directory(self) -> Path:
-        """Get the states directory.
-
-        Returns:
-            The states directory.
-        """
-        return prerequisites.get_states_dir()
-
-    def _purge_expired_states(self):
-        """Purge expired states from the disk."""
-        import time
-
-        for path in path_ops.ls(self.states_directory):
-            # check path is a pickle file
-            if path.suffix != ".pkl":
-                continue
-
-            # load last edited field from file
-            last_edited = path.stat().st_mtime
-
-            # check if the file is older than the token expiration time
-            if time.time() - last_edited > self.token_expiration:
-                # remove the file
-                path.unlink()
-
-    def token_path(self, token: str) -> Path:
-        """Get the path for a token.
-
-        Args:
-            token: The token to get the path for.
-
-        Returns:
-            The path for the token.
-        """
-        return (
-            self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
-        ).absolute()
-
-    async def load_state(self, token: str) -> BaseState | None:
-        """Load a state object based on the provided token.
-
-        Args:
-            token: The token used to identify the state object.
-
-        Returns:
-            The loaded state object or None.
-        """
-        token_path = self.token_path(token)
-
-        if token_path.exists():
-            try:
-                with token_path.open(mode="rb") as file:
-                    return BaseState._deserialize(fp=file)
-            except Exception:
-                pass
-
-    async def populate_substates(
-        self, client_token: str, state: BaseState, root_state: BaseState
-    ):
-        """Populate the substates of a state object.
-
-        Args:
-            client_token: The client token.
-            state: The state object to populate.
-            root_state: The root state object.
-        """
-        for substate in state.get_substates():
-            substate_token = _substate_key(client_token, substate)
-
-            fresh_instance = await root_state.get_state(substate)
-            instance = await self.load_state(substate_token)
-            if instance is not None:
-                # Ensure all substates exist, even if they weren't serialized previously.
-                instance.substates = fresh_instance.substates
-            else:
-                instance = fresh_instance
-            state.substates[substate.get_name()] = instance
-            instance.parent_state = state
-
-            await self.populate_substates(client_token, instance, root_state)
-
-    @override
-    async def get_state(
-        self,
-        token: str,
-    ) -> BaseState:
-        """Get the state for a token.
-
-        Args:
-            token: The token to get the state for.
-
-        Returns:
-            The state for the token.
-        """
-        client_token = _split_substate_key(token)[0]
-        root_state = self.states.get(client_token)
-        if root_state is not None:
-            # Retrieved state from memory.
-            return root_state
-
-        # Deserialize root state from disk.
-        root_state = await self.load_state(_substate_key(client_token, self.state))
-        # Create a new root state tree with all substates instantiated.
-        fresh_root_state = self.state(_reflex_internal_init=True)
-        if root_state is None:
-            root_state = fresh_root_state
-        else:
-            # Ensure all substates exist, even if they were not serialized previously.
-            root_state.substates = fresh_root_state.substates
-        self.states[client_token] = root_state
-        await self.populate_substates(client_token, root_state, root_state)
-        return root_state
-
-    async def set_state_for_substate(self, client_token: str, substate: BaseState):
-        """Set the state for a substate.
-
-        Args:
-            client_token: The client token.
-            substate: The substate to set.
-        """
-        substate_token = _substate_key(client_token, substate)
-
-        if substate._get_was_touched():
-            substate._was_touched = False  # Reset the touched flag after serializing.
-            pickle_state = substate._serialize()
-            if pickle_state:
-                if not self.states_directory.exists():
-                    self.states_directory.mkdir(parents=True, exist_ok=True)
-                self.token_path(substate_token).write_bytes(pickle_state)
-
-        for substate_substate in substate.substates.values():
-            await self.set_state_for_substate(client_token, substate_substate)
-
-    @override
-    async def set_state(self, token: str, state: BaseState):
-        """Set the state for a token.
-
-        Args:
-            token: The token to set the state for.
-            state: The state to set.
-        """
-        client_token, substate = _split_substate_key(token)
-        await self.set_state_for_substate(client_token, state)
-
-    @override
-    @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
-        """Modify the state for a token while holding exclusive lock.
-
-        Args:
-            token: The token to modify the state for.
-
-        Yields:
-            The state for the token.
-        """
-        # Memory state manager ignores the substate suffix and always returns the top-level state.
-        client_token, substate = _split_substate_key(token)
-        if client_token not in self._states_locks:
-            async with self._state_manager_lock:
-                if client_token not in self._states_locks:
-                    self._states_locks[client_token] = asyncio.Lock()
-
-        async with self._states_locks[client_token]:
-            state = await self.get_state(token)
-            yield state
-            await self.set_state(token, state)
-
-
-def _default_lock_expiration() -> int:
-    """Get the default lock expiration time.
-
-    Returns:
-        The default lock expiration time.
-    """
-    return get_config().redis_lock_expiration
-
-
-def _default_lock_warning_threshold() -> int:
-    """Get the default lock warning threshold.
-
-    Returns:
-        The default lock warning threshold.
-    """
-    return get_config().redis_lock_warning_threshold
-
-
-class StateManagerRedis(StateManager):
-    """A state manager that stores states in redis."""
-
-    # The redis client to use.
-    redis: Redis
-
-    # The token expiration time (s).
-    token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
-
-    # The maximum time to hold a lock (ms).
-    lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
-
-    # The maximum time to hold a lock (ms) before warning.
-    lock_warning_threshold: int = pydantic.Field(
-        default_factory=_default_lock_warning_threshold
-    )
-
-    # The keyspace subscription string when redis is waiting for lock to be released.
-    _redis_notify_keyspace_events: str = (
-        "K"  # Enable keyspace notifications (target a particular key)
-        "g"  # For generic commands (DEL, EXPIRE, etc)
-        "x"  # For expired events
-        "e"  # For evicted events (i.e. maxmemory exceeded)
-    )
-
-    # These events indicate that a lock is no longer held.
-    _redis_keyspace_lock_release_events: set[bytes] = {
-        b"del",
-        b"expire",
-        b"expired",
-        b"evicted",
-    }
-
-    # Whether keyspace notifications have been enabled.
-    _redis_notify_keyspace_events_enabled: bool = False
-
-    # The logical database number used by the redis client.
-    _redis_db: int = 0
-
-    def _get_required_state_classes(
-        self,
-        target_state_cls: type[BaseState],
-        subclasses: bool = False,
-        required_state_classes: set[type[BaseState]] | None = None,
-    ) -> set[type[BaseState]]:
-        """Recursively determine which states are required to fetch the target state.
-
-        This will always include potentially dirty substates that depend on vars
-        in the target_state_cls.
-
-        Args:
-            target_state_cls: The target state class being fetched.
-            subclasses: Whether to include subclasses of the target state.
-            required_state_classes: Recursive argument tracking state classes that have already been seen.
-
-        Returns:
-            The set of state classes required to fetch the target state.
-        """
-        if required_state_classes is None:
-            required_state_classes = set()
-        # Get the substates if requested.
-        if subclasses:
-            for substate in target_state_cls.get_substates():
-                self._get_required_state_classes(
-                    substate,
-                    subclasses=True,
-                    required_state_classes=required_state_classes,
-                )
-        if target_state_cls in required_state_classes:
-            return required_state_classes
-        required_state_classes.add(target_state_cls)
-
-        # Get dependent substates.
-        for pd_substates in target_state_cls._get_potentially_dirty_states():
-            self._get_required_state_classes(
-                pd_substates,
-                subclasses=False,
-                required_state_classes=required_state_classes,
-            )
-
-        # Get the parent state if it exists.
-        if parent_state := target_state_cls.get_parent_state():
-            self._get_required_state_classes(
-                parent_state,
-                subclasses=False,
-                required_state_classes=required_state_classes,
-            )
-        return required_state_classes
-
-    def _get_populated_states(
-        self,
-        target_state: BaseState,
-        populated_states: dict[str, BaseState] | None = None,
-    ) -> dict[str, BaseState]:
-        """Recursively determine which states from target_state are already fetched.
-
-        Args:
-            target_state: The state to check for populated states.
-            populated_states: Recursive argument tracking states seen in previous calls.
-
-        Returns:
-            A dictionary of state full name to state instance.
-        """
-        if populated_states is None:
-            populated_states = {}
-        if target_state.get_full_name() in populated_states:
-            return populated_states
-        populated_states[target_state.get_full_name()] = target_state
-        for substate in target_state.substates.values():
-            self._get_populated_states(substate, populated_states=populated_states)
-        if target_state.parent_state is not None:
-            self._get_populated_states(
-                target_state.parent_state, populated_states=populated_states
-            )
-        return populated_states
-
-    @override
-    async def get_state(
-        self,
-        token: str,
-        top_level: bool = True,
-        for_state_instance: BaseState | None = None,
-    ) -> BaseState:
-        """Get the state for a token.
-
-        Args:
-            token: The token to get the state for.
-            top_level: If true, return an instance of the top-level state (self.state).
-            for_state_instance: If provided, attach the requested states to this existing state tree.
-
-        Returns:
-            The state for the token.
-
-        Raises:
-            RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
-                requested state was not fetched.
-        """
-        # Split the actual token from the fully qualified substate name.
-        token, state_path = _split_substate_key(token)
-        if state_path:
-            # Get the State class associated with the given path.
-            state_cls = self.state.get_class_substate(state_path)
-        else:
-            raise RuntimeError(
-                f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
-            )
-
-        # Determine which states we already have.
-        flat_state_tree: dict[str, BaseState] = (
-            self._get_populated_states(for_state_instance) if for_state_instance else {}
-        )
-
-        # Determine which states from the tree need to be fetched.
-        required_state_classes = sorted(
-            self._get_required_state_classes(state_cls, subclasses=True)
-            - {type(s) for s in flat_state_tree.values()},
-            key=lambda x: x.get_full_name(),
-        )
-
-        redis_pipeline = self.redis.pipeline()
-        for state_cls in required_state_classes:
-            redis_pipeline.get(_substate_key(token, state_cls))
-
-        for state_cls, redis_state in zip(
-            required_state_classes,
-            await redis_pipeline.execute(),
-            strict=False,
-        ):
-            state = None
-
-            if redis_state is not None:
-                # Deserialize the substate.
-                with contextlib.suppress(StateSchemaMismatchError):
-                    state = BaseState._deserialize(data=redis_state)
-            if state is None:
-                # Key didn't exist or schema mismatch so create a new instance for this token.
-                state = state_cls(
-                    init_substates=False,
-                    _reflex_internal_init=True,
-                )
-            flat_state_tree[state.get_full_name()] = state
-            if state.get_parent_state() is not None:
-                parent_state_name, _dot, state_name = state.get_full_name().rpartition(
-                    "."
-                )
-                parent_state = flat_state_tree.get(parent_state_name)
-                if parent_state is None:
-                    raise RuntimeError(
-                        f"Parent state for {state.get_full_name()} was not found "
-                        "in the state tree, but should have already been fetched. "
-                        "This is a bug",
-                    )
-                parent_state.substates[state_name] = state
-                state.parent_state = parent_state
-
-        # To retain compatibility with previous implementation, by default, we return
-        # the top-level state which should always be fetched or already cached.
-        if top_level:
-            return flat_state_tree[self.state.get_full_name()]
-        return flat_state_tree[state_cls.get_full_name()]
-
-    @override
-    async def set_state(
-        self,
-        token: str,
-        state: BaseState,
-        lock_id: bytes | None = None,
-    ):
-        """Set the state for a token.
-
-        Args:
-            token: The token to set the state for.
-            state: The state to set.
-            lock_id: If provided, the lock_key must be set to this value to set the state.
-
-        Raises:
-            LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
-            RuntimeError: If the state instance doesn't match the state name in the token.
-        """
-        # Check that we're holding the lock.
-        if (
-            lock_id is not None
-            and await self.redis.get(self._lock_key(token)) != lock_id
-        ):
-            raise LockExpiredError(
-                f"Lock expired for token {token} while processing. Consider increasing "
-                f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
-                "or use `@rx.event(background=True)` decorator for long-running tasks."
-            )
-        elif lock_id is not None:
-            time_taken = self.lock_expiration / 1000 - (
-                await self.redis.ttl(self._lock_key(token))
-            )
-            if time_taken > self.lock_warning_threshold / 1000:
-                console.warn(
-                    f"Lock for token {token} was held too long {time_taken=}s, "
-                    f"use `@rx.event(background=True)` decorator for long-running tasks.",
-                    dedupe=True,
-                )
-
-        client_token, substate_name = _split_substate_key(token)
-        # If the substate name on the token doesn't match the instance name, it cannot have a parent.
-        if state.parent_state is not None and state.get_full_name() != substate_name:
-            raise RuntimeError(
-                f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
-            )
-
-        # Recursively set_state on all known substates.
-        tasks = [
-            asyncio.create_task(
-                self.set_state(
-                    _substate_key(client_token, substate),
-                    substate,
-                    lock_id,
-                )
-            )
-            for substate in state.substates.values()
-        ]
-        # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
-        if state._get_was_touched():
-            pickle_state = state._serialize()
-            if pickle_state:
-                await self.redis.set(
-                    _substate_key(client_token, state),
-                    pickle_state,
-                    ex=self.token_expiration,
-                )
-
-        # Wait for substates to be persisted.
-        for t in tasks:
-            await t
-
-    @override
-    @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
-        """Modify the state for a token while holding exclusive lock.
-
-        Args:
-            token: The token to modify the state for.
-
-        Yields:
-            The state for the token.
-        """
-        async with self._lock(token) as lock_id:
-            state = await self.get_state(token)
-            yield state
-            await self.set_state(token, state, lock_id)
-
-    @validator("lock_warning_threshold")
-    @classmethod
-    def validate_lock_warning_threshold(
-        cls, lock_warning_threshold: int, values: dict[str, int]
-    ):
-        """Validate the lock warning threshold.
-
-        Args:
-            lock_warning_threshold: The lock warning threshold.
-            values: The validated attributes.
-
-        Returns:
-            The lock warning threshold.
-
-        Raises:
-            InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
-        """
-        if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
-            raise InvalidLockWarningThresholdError(
-                f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
-            )
-        return lock_warning_threshold
-
-    @staticmethod
-    def _lock_key(token: str) -> bytes:
-        """Get the redis key for a token's lock.
-
-        Args:
-            token: The token to get the lock key for.
-
-        Returns:
-            The redis lock key for the token.
-        """
-        # All substates share the same lock domain, so ignore any substate path suffix.
-        client_token = _split_substate_key(token)[0]
-        return f"{client_token}_lock".encode()
-
-    async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
-        """Try to get a redis lock for a token.
-
-        Args:
-            lock_key: The redis key for the lock.
-            lock_id: The ID of the lock.
-
-        Returns:
-            True if the lock was obtained.
-        """
-        return await self.redis.set(
-            lock_key,
-            lock_id,
-            px=self.lock_expiration,
-            nx=True,  # only set if it doesn't exist
-        )
-
-    async def _get_pubsub_message(
-        self, pubsub: PubSub, timeout: float | None = None
-    ) -> None:
-        """Get lock release events from the pubsub.
-
-        Args:
-            pubsub: The pubsub to get a message from.
-            timeout: Remaining time to wait for a message.
-
-        Returns:
-            The message.
-        """
-        if timeout is None:
-            timeout = self.lock_expiration / 1000.0
-
-        started = time.time()
-        message = await pubsub.get_message(
-            ignore_subscribe_messages=True,
-            timeout=timeout,
-        )
-        if (
-            message is None
-            or message["data"] not in self._redis_keyspace_lock_release_events
-        ):
-            remaining = timeout - (time.time() - started)
-            if remaining <= 0:
-                return
-            await self._get_pubsub_message(pubsub, timeout=remaining)
-
-    async def _enable_keyspace_notifications(self):
-        """Enable keyspace notifications for the redis server.
-
-        Raises:
-            ResponseError: when the keyspace config cannot be set.
-        """
-        if self._redis_notify_keyspace_events_enabled:
-            return
-        # Find out which logical database index is being used.
-        self._redis_db = self.redis.get_connection_kwargs().get("db", self._redis_db)
-
-        try:
-            await self.redis.config_set(
-                "notify-keyspace-events",
-                self._redis_notify_keyspace_events,
-            )
-        except ResponseError:
-            # Some redis servers only allow out-of-band configuration, so ignore errors here.
-            if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
-                raise
-        self._redis_notify_keyspace_events_enabled = True
-
-    async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
-        """Wait for a redis lock to be released via pubsub.
-
-        Coroutine will not return until the lock is obtained.
-
-        Args:
-            lock_key: The redis key for the lock.
-            lock_id: The ID of the lock.
-        """
-        # Enable keyspace notifications for the lock key, so we know when it is available.
-        await self._enable_keyspace_notifications()
-        lock_key_channel = f"__keyspace@{self._redis_db}__:{lock_key.decode()}"
-        async with self.redis.pubsub() as pubsub:
-            await pubsub.psubscribe(lock_key_channel)
-            # wait for the lock to be released
-            while True:
-                # fast path
-                if await self._try_get_lock(lock_key, lock_id):
-                    return
-                # wait for lock events
-                await self._get_pubsub_message(pubsub)
-
-    @contextlib.asynccontextmanager
-    async def _lock(self, token: str):
-        """Obtain a redis lock for a token.
-
-        Args:
-            token: The token to obtain a lock for.
-
-        Yields:
-            The ID of the lock (to be passed to set_state).
-
-        Raises:
-            LockExpiredError: If the lock has expired while processing the event.
-        """
-        lock_key = self._lock_key(token)
-        lock_id = uuid.uuid4().hex.encode()
-
-        if not await self._try_get_lock(lock_key, lock_id):
-            # Missed the fast-path to get lock, subscribe for lock delete/expire events
-            await self._wait_lock(lock_key, lock_id)
-        state_is_locked = True
-
-        try:
-            yield lock_id
-        except LockExpiredError:
-            state_is_locked = False
-            raise
-        finally:
-            if state_is_locked:
-                # only delete our lock
-                await self.redis.delete(lock_key)
-
-    async def close(self):
-        """Explicitly close the redis connection and connection_pool.
-
-        It is necessary in testing scenarios to close between asyncio test cases
-        to avoid having lingering redis connections associated with event loops
-        that will be closed (each test case uses its own event loop).
-
-        Note: Connections will be automatically reopened when needed.
-        """
-        await self.redis.aclose(close_connection_pool=True)
-
-
-def get_state_manager() -> StateManager:
-    """Get the state manager for the app that is currently running.
-
-    Returns:
-        The state manager.
-    """
-    return prerequisites.get_and_validate_app().app.state_manager
-
-
-class MutableProxy(wrapt.ObjectProxy):
-    """A proxy for a mutable object that tracks changes."""
-
-    # Hint for finding the base class of the proxy.
-    __base_proxy__ = "MutableProxy"
-
-    # Methods on wrapped objects which should mark the state as dirty.
-    __mark_dirty_attrs__ = {
-        "add",
-        "append",
-        "clear",
-        "difference_update",
-        "discard",
-        "extend",
-        "insert",
-        "intersection_update",
-        "pop",
-        "popitem",
-        "remove",
-        "reverse",
-        "setdefault",
-        "sort",
-        "symmetric_difference_update",
-        "update",
-    }
-
-    # Methods on wrapped objects might return mutable objects that should be tracked.
-    __wrap_mutable_attrs__ = {
-        "get",
-        "setdefault",
-    }
-
-    # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
-    __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
-        pydantic.BaseModel.__dict__
-    )
-
-    # These types will be wrapped in MutableProxy
-    __mutable_types__ = (
-        list,
-        dict,
-        set,
-        Base,
-        DeclarativeBase,
-        BaseModelV2,
-        BaseModelV1,
-    )
-
-    # Dynamically generated classes for tracking dataclass mutations.
-    __dataclass_proxies__: dict[type, type] = {}
-
-    def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
-        """Create a proxy instance for a mutable object that tracks changes.
-
-        Args:
-            wrapped: The object to proxy.
-            *args: Other args passed to MutableProxy (ignored).
-            **kwargs: Other kwargs passed to MutableProxy (ignored).
-
-        Returns:
-            The proxy instance.
-        """
-        if dataclasses.is_dataclass(wrapped):
-            wrapped_cls = type(wrapped)
-            wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
-            # Find the associated class
-            if wrapper_cls_name not in cls.__dataclass_proxies__:
-                # Create a new class that has the __dataclass_fields__ defined
-                cls.__dataclass_proxies__[wrapper_cls_name] = type(
-                    wrapper_cls_name,
-                    (cls,),
-                    {
-                        dataclasses._FIELDS: getattr(  # pyright: ignore [reportAttributeAccessIssue]
-                            wrapped_cls,
-                            dataclasses._FIELDS,  # pyright: ignore [reportAttributeAccessIssue]
-                        ),
-                    },
-                )
-            cls = cls.__dataclass_proxies__[wrapper_cls_name]
-        return super().__new__(cls)
-
-    def __init__(self, wrapped: Any, state: BaseState, field_name: str):
-        """Create a proxy for a mutable object that tracks changes.
-
-        Args:
-            wrapped: The object to proxy.
-            state: The state to mark dirty when the object is changed.
-            field_name: The name of the field on the state associated with the
-                wrapped object.
-        """
-        super().__init__(wrapped)
-        self._self_state = state
-        self._self_field_name = field_name
-
-    def __repr__(self) -> str:
-        """Get the representation of the wrapped object.
-
-        Returns:
-            The representation of the wrapped object.
-        """
-        return f"{type(self).__name__}({self.__wrapped__})"
-
-    def _mark_dirty(
-        self,
-        wrapped: Callable | None = None,
-        instance: BaseState | None = None,
-        args: tuple = (),
-        kwargs: dict | None = None,
-    ) -> Any:
-        """Mark the state as dirty, then call a wrapped function.
-
-        Intended for use with `FunctionWrapper` from the `wrapt` library.
-
-        Args:
-            wrapped: The wrapped function.
-            instance: The instance of the wrapped function.
-            args: The args for the wrapped function.
-            kwargs: The kwargs for the wrapped function.
-
-        Returns:
-            The result of the wrapped function.
-        """
-        self._self_state.dirty_vars.add(self._self_field_name)
-        self._self_state._mark_dirty()
-        if wrapped is not None:
-            return wrapped(*args, **(kwargs or {}))
-
-    @classmethod
-    def _is_mutable_type(cls, value: Any) -> bool:
-        """Check if a value is of a mutable type and should be wrapped.
-
-        Args:
-            value: The value to check.
-
-        Returns:
-            Whether the value is of a mutable type.
-        """
-        return isinstance(value, cls.__mutable_types__) or (
-            dataclasses.is_dataclass(value) and not isinstance(value, Var)
-        )
-
-    @staticmethod
-    def _is_called_from_dataclasses_internal() -> bool:
-        """Check if the current function is called from dataclasses helper.
-
-        Returns:
-            Whether the current function is called from dataclasses internal code.
-        """
-        # Walk up the stack a bit to see if we are called from dataclasses
-        # internal code, for example `asdict` or `astuple`.
-        frame = inspect.currentframe()
-        for _ in range(5):
-            # Why not `inspect.stack()` -- this is much faster!
-            if not (frame := frame and frame.f_back):
-                break
-            if inspect.getfile(frame) == dataclasses.__file__:
-                return True
-        return False
-
-    def _wrap_recursive(self, value: Any) -> Any:
-        """Wrap a value recursively if it is mutable.
-
-        Args:
-            value: The value to wrap.
-
-        Returns:
-            The wrapped value.
-        """
-        # When called from dataclasses internal code, return the unwrapped value
-        if self._is_called_from_dataclasses_internal():
-            return value
-        # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
-        if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
-            base_cls = globals()[self.__base_proxy__]
-            return base_cls(
-                wrapped=value,
-                state=self._self_state,
-                field_name=self._self_field_name,
-            )
-        return value
-
-    def _wrap_recursive_decorator(
-        self, wrapped: Callable, instance: BaseState, args: list, kwargs: dict
-    ) -> Any:
-        """Wrap a function that returns a possibly mutable value.
-
-        Intended for use with `FunctionWrapper` from the `wrapt` library.
-
-        Args:
-            wrapped: The wrapped function.
-            instance: The instance of the wrapped function.
-            args: The args for the wrapped function.
-            kwargs: The kwargs for the wrapped function.
-
-        Returns:
-            The result of the wrapped function (possibly wrapped in a MutableProxy).
-        """
-        return self._wrap_recursive(wrapped(*args, **kwargs))
-
-    def __getattr__(self, __name: str) -> Any:
-        """Get the attribute on the proxied object and return a proxy if mutable.
-
-        Args:
-            __name: The name of the attribute.
-
-        Returns:
-            The attribute value.
-        """
-        value = super().__getattr__(__name)
-
-        if callable(value):
-            if __name in self.__mark_dirty_attrs__:
-                # Wrap special callables, like "append", which should mark state dirty.
-                value = wrapt.FunctionWrapper(value, self._mark_dirty)
-
-            if __name in self.__wrap_mutable_attrs__:
-                # Wrap methods that may return mutable objects tied to the state.
-                value = wrapt.FunctionWrapper(
-                    value,
-                    self._wrap_recursive_decorator,
-                )
-
-            if (
-                isinstance(self.__wrapped__, Base)
-                and __name not in self.__never_wrap_base_attrs__
-                and hasattr(value, "__func__")
-            ):
-                # Wrap methods called on Base subclasses, which might do _anything_
-                return wrapt.FunctionWrapper(
-                    functools.partial(value.__func__, self),  # pyright: ignore [reportFunctionMemberAccess]
-                    self._wrap_recursive_decorator,
-                )
-
-        if self._is_mutable_type(value) and __name not in (
-            "__wrapped__",
-            "_self_state",
-            "__dict__",
-        ):
-            # Recursively wrap mutable attribute values retrieved through this proxy.
-            return self._wrap_recursive(value)
-
-        return value
-
-    def __getitem__(self, key: Any) -> Any:
-        """Get the item on the proxied object and return a proxy if mutable.
-
-        Args:
-            key: The key of the item.
-
-        Returns:
-            The item value.
-        """
-        value = super().__getitem__(key)
-        if isinstance(key, slice) and isinstance(value, list):
-            return [self._wrap_recursive(item) for item in value]
-        # Recursively wrap mutable items retrieved through this proxy.
-        return self._wrap_recursive(value)
-
-    def __iter__(self) -> Any:
-        """Iterate over the proxied object and return a proxy if mutable.
-
-        Yields:
-            Each item value (possibly wrapped in MutableProxy).
-        """
-        for value in super().__iter__():
-            # Recursively wrap mutable items retrieved through this proxy.
-            yield self._wrap_recursive(value)
-
-    def __delattr__(self, name: str):
-        """Delete the attribute on the proxied object and mark state dirty.
-
-        Args:
-            name: The name of the attribute.
-        """
-        self._mark_dirty(super().__delattr__, args=(name,))
-
-    def __delitem__(self, key: str):
-        """Delete the item on the proxied object and mark state dirty.
-
-        Args:
-            key: The key of the item.
-        """
-        self._mark_dirty(super().__delitem__, args=(key,))
-
-    def __setitem__(self, key: str, value: Any):
-        """Set the item on the proxied object and mark state dirty.
-
-        Args:
-            key: The key of the item.
-            value: The value of the item.
-        """
-        self._mark_dirty(super().__setitem__, args=(key, value))
-
-    def __setattr__(self, name: str, value: Any):
-        """Set the attribute on the proxied object and mark state dirty.
-
-        If the attribute starts with "_self_", then the state is NOT marked
-        dirty as these are internal proxy attributes.
-
-        Args:
-            name: The name of the attribute.
-            value: The value of the attribute.
-        """
-        if name.startswith("_self_"):
-            # Special case attributes of the proxy itself, not applied to the wrapped object.
-            super().__setattr__(name, value)
-            return
-        self._mark_dirty(super().__setattr__, args=(name, value))
-
-    def __copy__(self) -> Any:
-        """Return a copy of the proxy.
-
-        Returns:
-            A copy of the wrapped object, unconnected to the proxy.
-        """
-        return copy.copy(self.__wrapped__)
-
-    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Any:
-        """Return a deepcopy of the proxy.
-
-        Args:
-            memo: The memo dict to use for the deepcopy.
-
-        Returns:
-            A deepcopy of the wrapped object, unconnected to the proxy.
-        """
-        return copy.deepcopy(self.__wrapped__, memo=memo)
-
-    def __reduce_ex__(self, protocol_version: SupportsIndex):
-        """Get the state for redis serialization.
-
-        This method is called by cloudpickle to serialize the object.
-
-        It explicitly serializes the wrapped object, stripping off the mutable proxy.
-
-        Args:
-            protocol_version: The protocol version.
-
-        Returns:
-            Tuple of (wrapped class, empty args, class __getstate__)
-        """
-        return self.__wrapped__.__reduce_ex__(protocol_version)
-
-
-@serializer
-def serialize_mutable_proxy(mp: MutableProxy):
-    """Return the wrapped value of a MutableProxy.
-
-    Args:
-        mp: The MutableProxy to serialize.
-
-    Returns:
-        The wrapped object.
-    """
-    return mp.__wrapped__
-
-
-_orig_json_encoder_default = json.JSONEncoder.default
-
-
-def _json_encoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
-    """Wrap JSONEncoder.default to handle MutableProxy objects.
-
-    Args:
-        self: the JSONEncoder instance.
-        o: the object to serialize.
-
-    Returns:
-        A JSON-able object.
-    """
-    try:
-        return o.__wrapped__
-    except AttributeError:
-        pass
-    return _orig_json_encoder_default(self, o)
-
-
-json.JSONEncoder.default = _json_encoder_default_wrapper
-
-
-class ImmutableMutableProxy(MutableProxy):
-    """A proxy for a mutable object that tracks changes.
-
-    This wrapper comes from StateProxy, and will raise an exception if an attempt is made
-    to modify the wrapped object when the StateProxy is immutable.
-    """
-
-    # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
-    __base_proxy__ = "ImmutableMutableProxy"
-
-    def _mark_dirty(
-        self,
-        wrapped: Callable | None = None,
-        instance: BaseState | None = None,
-        args: tuple = (),
-        kwargs: dict | None = None,
-    ) -> Any:
-        """Raise an exception when an attempt is made to modify the object.
-
-        Intended for use with `FunctionWrapper` from the `wrapt` library.
-
-        Args:
-            wrapped: The wrapped function.
-            instance: The instance of the wrapped function.
-            args: The args for the wrapped function.
-            kwargs: The kwargs for the wrapped function.
-
-        Returns:
-            The result of the wrapped function.
-
-        Raises:
-            ImmutableStateError: if the StateProxy is not mutable.
-        """
-        if not self._self_state._is_mutable():
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        return super()._mark_dirty(
-            wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
-        )
-
-
-def code_uses_state_contexts(javascript_code: str) -> bool:
-    """Check if the rendered Javascript uses state contexts.
-
-    Args:
-        javascript_code: The Javascript code to check.
-
-    Returns:
-        True if the code attempts to access a member of StateContexts.
-    """
-    return bool("useContext(StateContexts" in javascript_code)
-
-
-def reload_state_module(
-    module: str,
-    state: type[BaseState] = State,
-) -> None:
-    """Reset rx.State subclasses to avoid conflict when reloading.
-
-    Args:
-        module: The module to reload.
-        state: Recursive argument for the state class to reload.
-
-    """
-    # Clean out all potentially dirty states of reloaded modules.
-    for pd_state in tuple(state._potentially_dirty_states):
-        with contextlib.suppress(ValueError):
-            if (
-                state.get_root_state().get_class_substate(pd_state).__module__ == module
-                and module is not None
-            ):
-                state._potentially_dirty_states.remove(pd_state)
-    for subclass in tuple(state.class_subclasses):
-        reload_state_module(module=module, state=subclass)
-        if subclass.__module__ == module and module is not None:
-            all_base_state_classes.pop(subclass.get_full_name(), None)
-            state.class_subclasses.remove(subclass)
-            state._always_dirty_substates.discard(subclass.get_name())
-            state._var_dependencies = {}
-            state._init_var_dependency_dicts()
-    state.get_class_substate.cache_clear()
+from reflex.istate.manager import LockExpiredError as LockExpiredError  # noqa: E402
+from reflex.istate.manager import StateManager as StateManager  # noqa: E402
+from reflex.istate.manager import StateManagerDisk as StateManagerDisk  # noqa: E402
+from reflex.istate.manager import StateManagerMemory as StateManagerMemory  # noqa: E402
+from reflex.istate.manager import StateManagerRedis as StateManagerRedis  # noqa: E402
+from reflex.istate.manager import get_state_manager as get_state_manager  # noqa: E402
+from reflex.istate.manager import (  # noqa: E402
+    reset_disk_state_manager as reset_disk_state_manager,
+)

+ 17 - 10
tests/units/test_state.py

@@ -27,18 +27,20 @@ from reflex.app import App
 from reflex.base import Base
 from reflex.constants import CompileVars, RouteVar, SocketEvent
 from reflex.event import Event, EventHandler
+from reflex.istate.manager import (
+    LockExpiredError,
+    StateManager,
+    StateManagerDisk,
+    StateManagerMemory,
+    StateManagerRedis,
+)
 from reflex.state import (
     BaseState,
     ImmutableStateError,
-    LockExpiredError,
     MutableProxy,
     OnLoadInternalState,
     RouterData,
     State,
-    StateManager,
-    StateManagerDisk,
-    StateManagerMemory,
-    StateManagerRedis,
     StateProxy,
     StateUpdate,
     _substate_key,
@@ -1777,7 +1779,7 @@ def substate_token_redis(state_manager_redis, token):
 
 @pytest.mark.asyncio
 async def test_state_manager_lock_expire(
-    state_manager_redis: StateManager, token: str, substate_token_redis: str
+    state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
 ):
     """Test that the state manager lock expires and raises exception exiting context.
 
@@ -1799,7 +1801,7 @@ async def test_state_manager_lock_expire(
 
 @pytest.mark.asyncio
 async def test_state_manager_lock_expire_contend(
-    state_manager_redis: StateManager, token: str, substate_token_redis: str
+    state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
 ):
     """Test that the state manager lock expires and queued waiters proceed.
 
@@ -1844,7 +1846,10 @@ async def test_state_manager_lock_expire_contend(
 
 @pytest.mark.asyncio
 async def test_state_manager_lock_warning_threshold_contend(
-    state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
+    state_manager_redis: StateManagerRedis,
+    token: str,
+    substate_token_redis: str,
+    mocker,
 ):
     """Test that the state manager triggers a warning when lock contention exceeds the warning threshold.
 
@@ -3354,7 +3359,8 @@ config = rx.Config(
     with chdir(proj_root):
         # reload config for each parameter to avoid stale values
         reflex.config.get_config(reload=True)
-        from reflex.state import State, StateManager
+        from reflex.istate.manager import StateManager
+        from reflex.state import State
 
         state_manager = StateManager.create(state=State)
         assert state_manager.lock_expiration == expected_values[0]  # pyright: ignore [reportAttributeAccessIssue]
@@ -3392,7 +3398,8 @@ config = rx.Config(
     with chdir(proj_root):
         # reload config for each parameter to avoid stale values
         reflex.config.get_config(reload=True)
-        from reflex.state import State, StateManager
+        from reflex.istate.manager import StateManager
+        from reflex.state import State
 
         with pytest.raises(InvalidLockWarningThresholdError):
             StateManager.create(state=State)