Browse Source

[REF-3056] Config knob for redis StateManager expiration times (#3523)

Elijah Ahianyo 10 months ago
parent
commit
f037df0977
3 changed files with 62 additions and 3 deletions
  1. 6 0
      reflex/config.py
  2. 12 3
      reflex/state.py
  3. 44 0
      tests/test_state.py

+ 6 - 0
reflex/config.py

@@ -219,6 +219,12 @@ class Config(Base):
     # Number of gunicorn workers from user
     # Number of gunicorn workers from user
     gunicorn_workers: Optional[int] = None
     gunicorn_workers: Optional[int] = None
 
 
+    # Maximum expiration lock time for redis state manager
+    redis_lock_expiration: int = constants.Expiration.LOCK
+
+    # Token expiration time for redis state manager
+    redis_token_expiration: int = constants.Expiration.TOKEN
+
     # Attributes that were explicitly set by the user.
     # Attributes that were explicitly set by the user.
     _non_default_attributes: Set[str] = pydantic.PrivateAttr(set())
     _non_default_attributes: Set[str] = pydantic.PrivateAttr(set())
 
 

+ 12 - 3
reflex/state.py

@@ -40,6 +40,7 @@ from redis.exceptions import ResponseError
 
 
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
+from reflex.config import get_config
 from reflex.event import (
 from reflex.event import (
     BACKGROUND_TASK_MARKER,
     BACKGROUND_TASK_MARKER,
     Event,
     Event,
@@ -60,6 +61,7 @@ if TYPE_CHECKING:
 
 
 Delta = Dict[str, Any]
 Delta = Dict[str, Any]
 var = computed_var
 var = computed_var
+config = get_config()
 
 
 
 
 # If the state is this large, it's considered a performance issue.
 # If the state is this large, it's considered a performance issue.
@@ -2202,7 +2204,14 @@ class StateManager(Base, ABC):
         """
         """
         redis = prerequisites.get_redis()
         redis = prerequisites.get_redis()
         if redis is not None:
         if redis is not None:
-            return StateManagerRedis(state=state, redis=redis)
+            # make sure expiration values are obtained only from the config object on creation
+            config = get_config()
+            return StateManagerRedis(
+                state=state,
+                redis=redis,
+                token_expiration=config.redis_token_expiration,
+                lock_expiration=config.redis_lock_expiration,
+            )
         return StateManagerMemory(state=state)
         return StateManagerMemory(state=state)
 
 
     @abstractmethod
     @abstractmethod
@@ -2333,10 +2342,10 @@ class StateManagerRedis(StateManager):
     redis: Redis
     redis: Redis
 
 
     # The token expiration time (s).
     # The token expiration time (s).
-    token_expiration: int = constants.Expiration.TOKEN
+    token_expiration: int = config.redis_token_expiration
 
 
     # The maximum time to hold a lock (ms).
     # The maximum time to hold a lock (ms).
-    lock_expiration: int = constants.Expiration.LOCK
+    lock_expiration: int = config.redis_lock_expiration
 
 
     # The keyspace subscription string when redis is waiting for lock to be released
     # The keyspace subscription string when redis is waiting for lock to be released
     _redis_notify_keyspace_events: str = (
     _redis_notify_keyspace_events: str = (

+ 44 - 0
tests/test_state.py

@@ -7,6 +7,7 @@ import functools
 import json
 import json
 import os
 import os
 import sys
 import sys
+from textwrap import dedent
 from typing import Any, Dict, Generator, List, Optional, Union
 from typing import Any, Dict, Generator, List, Optional, Union
 from unittest.mock import AsyncMock, Mock
 from unittest.mock import AsyncMock, Mock
 
 
@@ -14,6 +15,8 @@ import pytest
 from plotly.graph_objects import Figure
 from plotly.graph_objects import Figure
 
 
 import reflex as rx
 import reflex as rx
+import reflex.config
+from reflex import constants
 from reflex.app import App
 from reflex.app import App
 from reflex.base import Base
 from reflex.base import Base
 from reflex.constants import CompileVars, RouteVar, SocketEvent
 from reflex.constants import CompileVars, RouteVar, SocketEvent
@@ -33,6 +36,7 @@ from reflex.state import (
     StateUpdate,
     StateUpdate,
     _substate_key,
     _substate_key,
 )
 )
+from reflex.testing import chdir
 from reflex.utils import format, prerequisites, types
 from reflex.utils import format, prerequisites, types
 from reflex.utils.format import json_dumps
 from reflex.utils.format import json_dumps
 from reflex.vars import BaseVar, ComputedVar
 from reflex.vars import BaseVar, ComputedVar
@@ -2925,3 +2929,43 @@ async def test_setvar(mock_app: rx.App, token: str):
     # Cannot setvar with non-string
     # Cannot setvar with non-string
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
         TestState.setvar(42, 42)
         TestState.setvar(42, 42)
+
+
+@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
+@pytest.mark.parametrize(
+    "expiration_kwargs, expected_values",
+    [
+        ({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
+        (
+            {"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
+            (50000, 5600),
+        ),
+        ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
+    ],
+)
+def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
+    proj_root = tmp_path / "project1"
+    proj_root.mkdir()
+
+    config_items = ",\n    ".join(
+        f"{key} = {value}" for key, value in expiration_kwargs.items()
+    )
+
+    config_string = f"""
+import reflex as rx
+config = rx.Config(
+    app_name="project1",
+    redis_url="redis://localhost:6379",
+    {config_items}
+)
+"""
+    (proj_root / "rxconfig.py").write_text(dedent(config_string))
+
+    with chdir(proj_root):
+        # reload config for each parameter to avoid stale values
+        reflex.config.get_config(reload=True)
+        from reflex.state import State, StateManager
+
+        state_manager = StateManager.create(state=State)
+        assert state_manager.lock_expiration == expected_values[0]  # type: ignore
+        assert state_manager.token_expiration == expected_values[1]  # type: ignore