1
0
Эх сурвалжийг харах

implement disk state manager (#3826)

* implement disk state manager

* put states inside of a folder

* return root state all the time

* factor out code

* add docs for token expiration

* cache states directory

* call absolute on web directory

* change dir to app path when testing the backend

* remove accidental 🥒

* test disk for now

* modify schema

* only serialize specific stuff

* fix issue in types

* what is a kilometer

* create folder if it doesn't exist in write

* this code hates me

* check if the file isn't empty

* add try except clause

* add check for directory again
Khaleel Al-Adhami 9 сар өмнө
parent
commit
629850162a

+ 2 - 0
reflex/constants/base.py

@@ -45,6 +45,8 @@ class Dirs(SimpleNamespace):
     REFLEX_JSON = "reflex.json"
     # The name of the postcss config file.
     POSTCSS_JS = "postcss.config.js"
+    # The name of the states directory.
+    STATES = "states"
 
 
 class Reflex(SimpleNamespace):

+ 242 - 11
reflex/state.py

@@ -11,6 +11,7 @@ import os
 import uuid
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from pathlib import Path
 from types import FunctionType, MethodType
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
     Optional,
     Sequence,
     Set,
+    Tuple,
     Type,
     Union,
     cast,
@@ -52,7 +54,7 @@ from reflex.event import (
     EventSpec,
     fix_events,
 )
-from reflex.utils import console, format, prerequisites, types
+from reflex.utils import console, format, path_ops, prerequisites, types
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import SerializedType, serialize, serializer
@@ -2339,7 +2341,7 @@ class StateManager(Base, ABC):
                 token_expiration=config.redis_token_expiration,
                 lock_expiration=config.redis_lock_expiration,
             )
-        return StateManagerMemory(state=state)
+        return StateManagerDisk(state=state)
 
     @abstractmethod
     async def get_state(self, token: str) -> BaseState:
@@ -2446,6 +2448,244 @@ class StateManagerMemory(StateManager):
             await self.set_state(token, state)
 
 
+def _default_token_expiration() -> int:
+    """Get the default token expiration time.
+
+    Returns:
+        The default token expiration time.
+    """
+    return get_config().redis_token_expiration
+
+
+def state_to_schema(
+    state: BaseState,
+) -> List[
+    Tuple[
+        str,
+        str,
+        Any,
+        Union[bool, None],
+    ]
+]:
+    """Convert a state to a schema.
+
+    Args:
+        state: The state to convert to a schema.
+
+    Returns:
+        The schema.
+    """
+    return list(
+        sorted(
+            (
+                field_name,
+                model_field.name,
+                model_field.type_,
+                (
+                    model_field.required
+                    if isinstance(model_field.required, bool)
+                    else None
+                ),
+            )
+            for field_name, model_field in state.__fields__.items()
+        )
+    )
+
+
+class StateManagerDisk(StateManager):
+    """A state manager that stores states in memory."""
+
+    # The mapping of client ids to states.
+    states: Dict[str, BaseState] = {}
+
+    # The mutex ensures the dict of mutexes is updated exclusively
+    _state_manager_lock = asyncio.Lock()
+
+    # The dict of mutexes for each client
+    _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
+
+    # The token expiration time (s).
+    token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
+
+    class Config:
+        """The Pydantic config."""
+
+        fields = {
+            "_states_locks": {"exclude": True},
+        }
+        keep_untouched = (functools.cached_property,)
+
+    def __init__(self, state: Type[BaseState]):
+        """Create a new state manager.
+
+        Args:
+            state: The state class to use.
+        """
+        super().__init__(state=state)
+
+        path_ops.mkdir(self.states_directory)
+
+        self._purge_expired_states()
+
+    @functools.cached_property
+    def states_directory(self) -> Path:
+        """Get the states directory.
+
+        Returns:
+            The states directory.
+        """
+        return prerequisites.get_web_dir() / constants.Dirs.STATES
+
+    def _purge_expired_states(self):
+        """Purge expired states from the disk."""
+        import time
+
+        for path in path_ops.ls(self.states_directory):
+            # check path is a pickle file
+            if path.suffix != ".pkl":
+                continue
+
+            # load last edited field from file
+            last_edited = path.stat().st_mtime
+
+            # check if the file is older than the token expiration time
+            if time.time() - last_edited > self.token_expiration:
+                # remove the file
+                path.unlink()
+
+    def token_path(self, token: str) -> Path:
+        """Get the path for a token.
+
+        Args:
+            token: The token to get the path for.
+
+        Returns:
+            The path for the token.
+        """
+        return (self.states_directory / f"{token}.pkl").absolute()
+
+    async def load_state(self, token: str, root_state: BaseState) -> BaseState:
+        """Load a state object based on the provided token.
+
+        Args:
+            token: The token used to identify the state object.
+            root_state: The root state object.
+
+        Returns:
+            The loaded state object.
+        """
+        if token in self.states:
+            return self.states[token]
+
+        client_token, substate_address = _split_substate_key(token)
+
+        token_path = self.token_path(token)
+
+        if token_path.exists():
+            try:
+                with token_path.open(mode="rb") as file:
+                    (substate_schema, substate) = dill.load(file)
+                if substate_schema == state_to_schema(substate):
+                    await self.populate_substates(client_token, substate, root_state)
+                    return substate
+            except Exception:
+                pass
+
+        return root_state.get_substate(substate_address.split(".")[1:])
+
+    async def populate_substates(
+        self, client_token: str, state: BaseState, root_state: BaseState
+    ):
+        """Populate the substates of a state object.
+
+        Args:
+            client_token: The client token.
+            state: The state object to populate.
+            root_state: The root state object.
+        """
+        for substate in state.get_substates():
+            substate_token = _substate_key(client_token, substate)
+
+            substate = await self.load_state(substate_token, root_state)
+
+            state.substates[substate.get_name()] = substate
+            substate.parent_state = state
+
+    @override
+    async def get_state(
+        self,
+        token: str,
+    ) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        client_token, substate_address = _split_substate_key(token)
+
+        root_state_token = _substate_key(client_token, substate_address.split(".")[0])
+
+        return await self.load_state(
+            root_state_token, self.state(_reflex_internal_init=True)
+        )
+
+    async def set_state_for_substate(self, client_token: str, substate: BaseState):
+        """Set the state for a substate.
+
+        Args:
+            client_token: The client token.
+            substate: The substate to set.
+        """
+        substate_token = _substate_key(client_token, substate)
+
+        self.states[substate_token] = substate
+
+        state_dilled = dill.dumps((state_to_schema(substate), substate), byref=True)
+        if not self.states_directory.exists():
+            self.states_directory.mkdir(parents=True, exist_ok=True)
+        self.token_path(substate_token).write_bytes(state_dilled)
+
+        for substate_substate in substate.substates.values():
+            await self.set_state_for_substate(client_token, substate_substate)
+
+    @override
+    async def set_state(self, token: str, state: BaseState):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        client_token, substate = _split_substate_key(token)
+        await self.set_state_for_substate(client_token, state)
+
+    @override
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        client_token, substate = _split_substate_key(token)
+        if client_token not in self._states_locks:
+            async with self._state_manager_lock:
+                if client_token not in self._states_locks:
+                    self._states_locks[client_token] = asyncio.Lock()
+
+        async with self._states_locks[client_token]:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state)
+
+
 # Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
 if not isinstance(State.validate.__func__, FunctionType):
     cython_function_or_method = type(State.validate.__func__)
@@ -2474,15 +2714,6 @@ def _default_lock_expiration() -> int:
     return get_config().redis_lock_expiration
 
 
-def _default_token_expiration() -> int:
-    """Get the default token expiration time.
-
-    Returns:
-        The default token expiration time.
-    """
-    return get_config().redis_token_expiration
-
-
 class StateManagerRedis(StateManager):
     """A state manager that stores states in redis."""
 

+ 8 - 3
reflex/testing.py

@@ -45,6 +45,8 @@ import reflex.utils.prerequisites
 import reflex.utils.processes
 from reflex.state import (
     BaseState,
+    StateManager,
+    StateManagerDisk,
     StateManagerMemory,
     StateManagerRedis,
     reload_state_module,
@@ -126,7 +128,7 @@ class AppHarness:
     frontend_output_thread: Optional[threading.Thread] = None
     backend_thread: Optional[threading.Thread] = None
     backend: Optional[uvicorn.Server] = None
-    state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
+    state_manager: Optional[StateManager] = None
     _frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
     _decorated_pages: list = dataclasses.field(default_factory=list)
 
@@ -290,6 +292,8 @@ class AppHarness:
         if isinstance(self.app_instance._state_manager, StateManagerRedis):
             # Create our own redis connection for testing.
             self.state_manager = StateManagerRedis.create(self.app_instance.state)
+        elif isinstance(self.app_instance._state_manager, StateManagerDisk):
+            self.state_manager = StateManagerDisk.create(self.app_instance.state)
         else:
             self.state_manager = self.app_instance._state_manager
 
@@ -327,7 +331,8 @@ class AppHarness:
             )
         )
         self.backend.shutdown = self._get_backend_shutdown_handler()
-        self.backend_thread = threading.Thread(target=self.backend.run)
+        with chdir(self.app_path):
+            self.backend_thread = threading.Thread(target=self.backend.run)
         self.backend_thread.start()
 
     async def _reset_backend_state_manager(self):
@@ -787,7 +792,7 @@ class AppHarness:
             raise RuntimeError("App is not running.")
         state_manager = self.app_instance.state_manager
         assert isinstance(
-            state_manager, StateManagerMemory
+            state_manager, (StateManagerMemory, StateManagerDisk)
         ), "Only works with memory state manager"
         if not self._poll_for(
             target=lambda: state_manager.states,

+ 12 - 0
reflex/utils/path_ops.py

@@ -81,6 +81,18 @@ def mkdir(path: str | Path):
     Path(path).mkdir(parents=True, exist_ok=True)
 
 
+def ls(path: str | Path) -> list[Path]:
+    """List the contents of a directory.
+
+    Args:
+        path: The path to the directory.
+
+    Returns:
+        A list of paths to the contents of the directory.
+    """
+    return list(Path(path).iterdir())
+
+
 def ln(src: str | Path, dest: str | Path, overwrite: bool = False) -> bool:
     """Create a symbolic link.
 

+ 4 - 1
tests/test_app.py

@@ -42,6 +42,7 @@ from reflex.state import (
     OnLoadInternalState,
     RouterData,
     State,
+    StateManagerDisk,
     StateManagerMemory,
     StateManagerRedis,
     StateUpdate,
@@ -1395,7 +1396,9 @@ def test_app_state_manager():
         app.state_manager
     app._enable_state()
     assert app.state_manager is not None
-    assert isinstance(app.state_manager, (StateManagerMemory, StateManagerRedis))
+    assert isinstance(
+        app.state_manager, (StateManagerMemory, StateManagerDisk, StateManagerRedis)
+    )
 
 
 def test_generate_component():

+ 22 - 9
tests/test_state.py

@@ -31,6 +31,7 @@ from reflex.state import (
     RouterData,
     State,
     StateManager,
+    StateManagerDisk,
     StateManagerMemory,
     StateManagerRedis,
     StateProxy,
@@ -1586,7 +1587,7 @@ async def test_state_with_invalid_yield(capsys, mock_app):
     assert "must only return/yield: None, Events or other EventHandlers" in captured.out
 
 
-@pytest.fixture(scope="function", params=["in_process", "redis"])
+@pytest.fixture(scope="function", params=["in_process", "disk", "redis"])
 def state_manager(request) -> Generator[StateManager, None, None]:
     """Instance of state manager parametrized for redis and in-process.
 
@@ -1600,8 +1601,11 @@ def state_manager(request) -> Generator[StateManager, None, None]:
     if request.param == "redis":
         if not isinstance(state_manager, StateManagerRedis):
             pytest.skip("Test requires redis")
-    else:
+    elif request.param == "disk":
         # explicitly NOT using redis
+        state_manager = StateManagerDisk(state=TestState)
+        assert not state_manager._states_locks
+    else:
         state_manager = StateManagerMemory(state=TestState)
         assert not state_manager._states_locks
 
@@ -1639,7 +1643,7 @@ async def test_state_manager_modify_state(
     async with state_manager.modify_state(substate_token) as state:
         if isinstance(state_manager, StateManagerRedis):
             assert await state_manager.redis.get(f"{token}_lock")
-        elif isinstance(state_manager, StateManagerMemory):
+        elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
             assert token in state_manager._states_locks
             assert state_manager._states_locks[token].locked()
         # Should be able to write proxy objects inside mutables
@@ -1649,11 +1653,11 @@ async def test_state_manager_modify_state(
     # lock should be dropped after exiting the context
     if isinstance(state_manager, StateManagerRedis):
         assert (await state_manager.redis.get(f"{token}_lock")) is None
-    elif isinstance(state_manager, StateManagerMemory):
+    elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
         assert not state_manager._states_locks[token].locked()
 
         # separate instances should NOT share locks
-        sm2 = StateManagerMemory(state=TestState)
+        sm2 = state_manager.__class__(state=TestState)
         assert sm2._state_manager_lock is state_manager._state_manager_lock
         assert not sm2._states_locks
         if state_manager._states_locks:
@@ -1691,7 +1695,7 @@ async def test_state_manager_contend(
 
     if isinstance(state_manager, StateManagerRedis):
         assert (await state_manager.redis.get(f"{token}_lock")) is None
-    elif isinstance(state_manager, StateManagerMemory):
+    elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
         assert token in state_manager._states_locks
         assert not state_manager._states_locks[token].locked()
 
@@ -1831,7 +1835,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     assert child_state is not None
     parent_state = child_state.parent_state
     assert parent_state is not None
-    if isinstance(mock_app.state_manager, StateManagerMemory):
+    if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
         mock_app.state_manager.states[parent_state.router.session.client_token] = (
             parent_state
         )
@@ -1874,7 +1878,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
             # For in-process store, only one instance of the state exists
             assert sp.__wrapped__ is grandchild_state
         else:
-            # When redis is used, a new+updated instance is assigned to the proxy
+            # When redis or disk is used, a new+updated instance is assigned to the proxy
             assert sp.__wrapped__ is not grandchild_state
         sp.value2 = "42"
     assert not sp._self_mutable  # proxy is not mutable after exiting context
@@ -2837,7 +2841,7 @@ async def test_get_state(mock_app: rx.App, token: str):
         _substate_key(token, ChildState2)
     )
     assert isinstance(test_state, TestState)
-    if isinstance(mock_app.state_manager, StateManagerMemory):
+    if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
         # All substates are available
         assert tuple(sorted(test_state.substates)) == (
             ChildState.get_name(),
@@ -2916,6 +2920,15 @@ async def test_get_state(mock_app: rx.App, token: str):
             ChildState2.get_name(),
             ChildState3.get_name(),
         )
+    elif isinstance(mock_app.state_manager, StateManagerDisk):
+        # On disk, it's a new instance
+        assert new_test_state is not test_state
+        # All substates are available
+        assert tuple(sorted(new_test_state.substates)) == (
+            ChildState.get_name(),
+            ChildState2.get_name(),
+            ChildState3.get_name(),
+        )
     else:
         # With redis, we get a whole new instance
         assert new_test_state is not test_state