Procházet zdrojové kódy

let users pick state manager mode (#4041)

Thomas Brandého před 7 měsíci
rodič
revize
6f586c8b8f

+ 16 - 0
reflex/config.py

@@ -9,6 +9,8 @@ import urllib.parse
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Union
 
+from reflex.utils.exceptions import ConfigError
+
 try:
     import pydantic.v1 as pydantic
 except ModuleNotFoundError:
@@ -220,6 +222,9 @@ class Config(Base):
     # Number of gunicorn workers from user
     gunicorn_workers: Optional[int] = None
 
+    # Indicate which type of state manager to use
+    state_manager_mode: constants.StateManagerMode = constants.StateManagerMode.DISK
+
     # Maximum expiration lock time for redis state manager
     redis_lock_expiration: int = constants.Expiration.LOCK
 
@@ -235,6 +240,9 @@ class Config(Base):
         Args:
             *args: The args to pass to the Pydantic init method.
             **kwargs: The kwargs to pass to the Pydantic init method.
+
+        Raises:
+            ConfigError: If some values in the config are invalid.
         """
         super().__init__(*args, **kwargs)
 
@@ -248,6 +256,14 @@ class Config(Base):
         self._non_default_attributes.update(kwargs)
         self._replace_defaults(**kwargs)
 
+        if (
+            self.state_manager_mode == constants.StateManagerMode.REDIS
+            and not self.redis_url
+        ):
+            raise ConfigError(
+                "REDIS_URL is required when using the redis state manager."
+            )
+
     @property
     def module(self) -> str:
         """Get the module name of the app.

+ 2 - 0
reflex/constants/__init__.py

@@ -63,6 +63,7 @@ from .route import (
     RouteRegex,
     RouteVar,
 )
+from .state import StateManagerMode
 from .style import Tailwind
 
 __ALL__ = [
@@ -115,6 +116,7 @@ __ALL__ = [
     SETTER_PREFIX,
     SKIP_COMPILE_ENV_VAR,
     SocketEvent,
+    StateManagerMode,
     Tailwind,
     Templates,
     CompileVars,

+ 11 - 0
reflex/constants/state.py

@@ -0,0 +1,11 @@
+"""State-related constants."""
+
+from enum import Enum
+
+
+class StateManagerMode(str, Enum):
+    """State manager constants."""
+
+    DISK = "disk"
+    MEMORY = "memory"
+    REDIS = "redis"

+ 25 - 14
reflex/state.py

@@ -76,6 +76,7 @@ from reflex.utils.exceptions import (
     DynamicRouteArgShadowsStateVar,
     EventHandlerShadowsBuiltInStateMethod,
     ImmutableStateError,
+    InvalidStateManagerMode,
     LockExpiredError,
     SetUndefinedStateVarError,
     StateSchemaMismatchError,
@@ -2514,20 +2515,30 @@ class StateManager(Base, ABC):
         Args:
             state: The state class to use.
 
-        Returns:
-            The state manager (either disk or redis).
-        """
-        redis = prerequisites.get_redis()
-        if redis is not None:
-            # make sure expiration values are obtained only from the config object on creation
-            config = get_config()
-            return StateManagerRedis(
-                state=state,
-                redis=redis,
-                token_expiration=config.redis_token_expiration,
-                lock_expiration=config.redis_lock_expiration,
-            )
-        return StateManagerDisk(state=state)
+        Raises:
+            InvalidStateManagerMode: If the state manager mode is invalid.
+
+        Returns:
+            The state manager (either disk, memory or redis).
+        """
+        config = get_config()
+        if config.state_manager_mode == constants.StateManagerMode.DISK:
+            return StateManagerMemory(state=state)
+        if config.state_manager_mode == constants.StateManagerMode.MEMORY:
+            return StateManagerDisk(state=state)
+        if config.state_manager_mode == constants.StateManagerMode.REDIS:
+            redis = prerequisites.get_redis()
+            if redis is not None:
+                # make sure expiration values are obtained only from the config object on creation
+                return StateManagerRedis(
+                    state=state,
+                    redis=redis,
+                    token_expiration=config.redis_token_expiration,
+                    lock_expiration=config.redis_lock_expiration,
+                )
+        raise InvalidStateManagerMode(
+            f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
+        )
 
     @abstractmethod
     async def get_state(self, token: str) -> BaseState:

+ 8 - 0
reflex/utils/exceptions.py

@@ -5,6 +5,14 @@ class ReflexError(Exception):
     """Base exception for all Reflex exceptions."""
 
 
+class ConfigError(ReflexError):
+    """Custom exception for config related errors."""
+
+
+class InvalidStateManagerMode(ReflexError, ValueError):
+    """Raised when an invalid state manager mode is provided."""
+
+
 class ReflexRuntimeError(ReflexError, RuntimeError):
     """Custom RuntimeError for Reflex."""
 

+ 1 - 0
tests/units/test_state.py

@@ -3201,6 +3201,7 @@ import reflex as rx
 config = rx.Config(
     app_name="project1",
     redis_url="redis://localhost:6379",
+    state_manager_mode="redis",
     {config_items}
 )
 """