Browse Source

[REF-3056] Config knob for redis StateManager expiration times (#3523)

Elijah Ahianyo 10 months ago
parent
commit
f037df0977
3 changed files with 62 additions and 3 deletions
  1. 6 0
      reflex/config.py
  2. 12 3
      reflex/state.py
  3. 44 0
      tests/test_state.py

+ 6 - 0
reflex/config.py

@@ -219,6 +219,12 @@ class Config(Base):
     # Number of gunicorn workers from user
     gunicorn_workers: Optional[int] = None
 
+    # Maximum expiration lock time for redis state manager
+    redis_lock_expiration: int = constants.Expiration.LOCK
+
+    # Token expiration time for redis state manager
+    redis_token_expiration: int = constants.Expiration.TOKEN
+
     # Attributes that were explicitly set by the user.
     _non_default_attributes: Set[str] = pydantic.PrivateAttr(set())
 

+ 12 - 3
reflex/state.py

@@ -40,6 +40,7 @@ from redis.exceptions import ResponseError
 
 from reflex import constants
 from reflex.base import Base
+from reflex.config import get_config
 from reflex.event import (
     BACKGROUND_TASK_MARKER,
     Event,
@@ -60,6 +61,7 @@ if TYPE_CHECKING:
 
 Delta = Dict[str, Any]
 var = computed_var
+config = get_config()
 
 
 # If the state is this large, it's considered a performance issue.
@@ -2202,7 +2204,14 @@ class StateManager(Base, ABC):
         """
         redis = prerequisites.get_redis()
         if redis is not None:
-            return StateManagerRedis(state=state, redis=redis)
+            # make sure expiration values are obtained only from the config object on creation
+            config = get_config()
+            return StateManagerRedis(
+                state=state,
+                redis=redis,
+                token_expiration=config.redis_token_expiration,
+                lock_expiration=config.redis_lock_expiration,
+            )
         return StateManagerMemory(state=state)
 
     @abstractmethod
@@ -2333,10 +2342,10 @@ class StateManagerRedis(StateManager):
     redis: Redis
 
     # The token expiration time (s).
-    token_expiration: int = constants.Expiration.TOKEN
+    token_expiration: int = config.redis_token_expiration
 
     # The maximum time to hold a lock (ms).
-    lock_expiration: int = constants.Expiration.LOCK
+    lock_expiration: int = config.redis_lock_expiration
 
     # The keyspace subscription string when redis is waiting for lock to be released
     _redis_notify_keyspace_events: str = (

+ 44 - 0
tests/test_state.py

@@ -7,6 +7,7 @@ import functools
 import json
 import os
 import sys
+from textwrap import dedent
 from typing import Any, Dict, Generator, List, Optional, Union
 from unittest.mock import AsyncMock, Mock
 
@@ -14,6 +15,8 @@ import pytest
 from plotly.graph_objects import Figure
 
 import reflex as rx
+import reflex.config
+from reflex import constants
 from reflex.app import App
 from reflex.base import Base
 from reflex.constants import CompileVars, RouteVar, SocketEvent
@@ -33,6 +36,7 @@ from reflex.state import (
     StateUpdate,
     _substate_key,
 )
+from reflex.testing import chdir
 from reflex.utils import format, prerequisites, types
 from reflex.utils.format import json_dumps
 from reflex.vars import BaseVar, ComputedVar
@@ -2925,3 +2929,43 @@ async def test_setvar(mock_app: rx.App, token: str):
     # Cannot setvar with non-string
     with pytest.raises(ValueError):
         TestState.setvar(42, 42)
+
+
+@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
+@pytest.mark.parametrize(
+    "expiration_kwargs, expected_values",
+    [
+        ({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
+        (
+            {"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
+            (50000, 5600),
+        ),
+        ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
+    ],
+)
+def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
+    proj_root = tmp_path / "project1"
+    proj_root.mkdir()
+
+    config_items = ",\n    ".join(
+        f"{key} = {value}" for key, value in expiration_kwargs.items()
+    )
+
+    config_string = f"""
+import reflex as rx
+config = rx.Config(
+    app_name="project1",
+    redis_url="redis://localhost:6379",
+    {config_items}
+)
+"""
+    (proj_root / "rxconfig.py").write_text(dedent(config_string))
+
+    with chdir(proj_root):
+        # reload config for each parameter to avoid stale values
+        reflex.config.get_config(reload=True)
+        from reflex.state import State, StateManager
+
+        state_manager = StateManager.create(state=State)
+        assert state_manager.lock_expiration == expected_values[0]  # type: ignore
+        assert state_manager.token_expiration == expected_values[1]  # type: ignore