Sfoglia il codice sorgente

[ENG-4100]Throw warnings when Redis lock is held for more than the allowed threshold (#4522)

* Throw warnings when Redis lock is held for more than the allowed threshold

* initial tests

* fix tests and address comments

* fix tests fr, and use pydantic validators

* darglint fix

* increase lock expiration in tests to 2500

* remove print statement

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Elijah Ahianyo 5 mesi fa
parent
commit
c387f517b6

+ 3 - 0
reflex/config.py

@@ -684,6 +684,9 @@ class Config(Base):
     # Maximum expiration lock time for redis state manager
     # Maximum expiration lock time for redis state manager
     redis_lock_expiration: int = constants.Expiration.LOCK
     redis_lock_expiration: int = constants.Expiration.LOCK
 
 
+    # Maximum lock time before warning for redis state manager.
+    redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD
+
     # Token expiration time for redis state manager
     # Token expiration time for redis state manager
     redis_token_expiration: int = constants.Expiration.TOKEN
     redis_token_expiration: int = constants.Expiration.TOKEN
 
 

+ 2 - 0
reflex/constants/config.py

@@ -29,6 +29,8 @@ class Expiration(SimpleNamespace):
     LOCK = 10000
     LOCK = 10000
     # The PING timeout
     # The PING timeout
     PING = 120
     PING = 120
+    # The maximum time in milliseconds to hold a lock before throwing a warning.
+    LOCK_WARNING_THRESHOLD = 1000
 
 
 
 
 class GitIgnore(SimpleNamespace):
 class GitIgnore(SimpleNamespace):

+ 53 - 0
reflex/state.py

@@ -71,6 +71,11 @@ try:
 except ModuleNotFoundError:
 except ModuleNotFoundError:
     BaseModelV1 = BaseModelV2
     BaseModelV1 = BaseModelV2
 
 
+try:
+    from pydantic.v1 import validator
+except ModuleNotFoundError:
+    from pydantic import validator
+
 import wrapt
 import wrapt
 from redis.asyncio import Redis
 from redis.asyncio import Redis
 from redis.exceptions import ResponseError
 from redis.exceptions import ResponseError
@@ -94,6 +99,7 @@ from reflex.utils.exceptions import (
     DynamicRouteArgShadowsStateVar,
     DynamicRouteArgShadowsStateVar,
     EventHandlerShadowsBuiltInStateMethod,
     EventHandlerShadowsBuiltInStateMethod,
     ImmutableStateError,
     ImmutableStateError,
+    InvalidLockWarningThresholdError,
     InvalidStateManagerMode,
     InvalidStateManagerMode,
     LockExpiredError,
     LockExpiredError,
     ReflexRuntimeError,
     ReflexRuntimeError,
@@ -2834,6 +2840,7 @@ class StateManager(Base, ABC):
                     redis=redis,
                     redis=redis,
                     token_expiration=config.redis_token_expiration,
                     token_expiration=config.redis_token_expiration,
                     lock_expiration=config.redis_lock_expiration,
                     lock_expiration=config.redis_lock_expiration,
+                    lock_warning_threshold=config.redis_lock_warning_threshold,
                 )
                 )
         raise InvalidStateManagerMode(
         raise InvalidStateManagerMode(
             f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
             f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
@@ -3203,6 +3210,15 @@ def _default_lock_expiration() -> int:
     return get_config().redis_lock_expiration
     return get_config().redis_lock_expiration
 
 
 
 
+def _default_lock_warning_threshold() -> int:
+    """Get the default lock warning threshold.
+
+    Returns:
+        The default lock warning threshold.
+    """
+    return get_config().redis_lock_warning_threshold
+
+
 class StateManagerRedis(StateManager):
 class StateManagerRedis(StateManager):
     """A state manager that stores states in redis."""
     """A state manager that stores states in redis."""
 
 
@@ -3215,6 +3231,11 @@ class StateManagerRedis(StateManager):
     # The maximum time to hold a lock (ms).
     # The maximum time to hold a lock (ms).
     lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
     lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
 
 
+    # The maximum time to hold a lock (ms) before warning.
+    lock_warning_threshold: int = pydantic.Field(
+        default_factory=_default_lock_warning_threshold
+    )
+
     # The keyspace subscription string when redis is waiting for lock to be released
     # The keyspace subscription string when redis is waiting for lock to be released
     _redis_notify_keyspace_events: str = (
     _redis_notify_keyspace_events: str = (
         "K"  # Enable keyspace notifications (target a particular key)
         "K"  # Enable keyspace notifications (target a particular key)
@@ -3402,6 +3423,17 @@ class StateManagerRedis(StateManager):
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                 "or use `@rx.event(background=True)` decorator for long-running tasks."
                 "or use `@rx.event(background=True)` decorator for long-running tasks."
             )
             )
+        elif lock_id is not None:
+            time_taken = self.lock_expiration / 1000 - (
+                await self.redis.ttl(self._lock_key(token))
+            )
+            if time_taken > self.lock_warning_threshold / 1000:
+                console.warn(
+                    f"Lock for token {token} was held too long {time_taken=}s, "
+                    f"use `@rx.event(background=True)` decorator for long-running tasks.",
+                    dedupe=True,
+                )
+
         client_token, substate_name = _split_substate_key(token)
         client_token, substate_name = _split_substate_key(token)
         # If the substate name on the token doesn't match the instance name, it cannot have a parent.
         # If the substate name on the token doesn't match the instance name, it cannot have a parent.
         if state.parent_state is not None and state.get_full_name() != substate_name:
         if state.parent_state is not None and state.get_full_name() != substate_name:
@@ -3451,6 +3483,27 @@ class StateManagerRedis(StateManager):
             yield state
             yield state
             await self.set_state(token, state, lock_id)
             await self.set_state(token, state, lock_id)
 
 
+    @validator("lock_warning_threshold")
+    @classmethod
+    def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
+        """Validate the lock warning threshold.
+
+        Args:
+            lock_warning_threshold: The lock warning threshold.
+            values: The validated attributes.
+
+        Returns:
+            The lock warning threshold.
+
+        Raises:
+            InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
+        """
+        if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
+            raise InvalidLockWarningThresholdError(
+                f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
+            )
+        return lock_warning_threshold
+
     @staticmethod
     @staticmethod
     def _lock_key(token: str) -> bytes:
     def _lock_key(token: str) -> bytes:
         """Get the redis key for a token's lock.
         """Get the redis key for a token's lock.

+ 60 - 6
reflex/utils/console.py

@@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set()
 # Info messages which have been printed.
 # Info messages which have been printed.
 _EMITTED_INFO = set()
 _EMITTED_INFO = set()
 
 
+# Warnings which have been printed.
+_EMIITED_WARNINGS = set()
+
+# Errors which have been printed.
+_EMITTED_ERRORS = set()
+
+# Success messages which have been printed.
+_EMITTED_SUCCESS = set()
+
+# Debug messages which have been printed.
+_EMITTED_DEBUG = set()
+
+# Logs which have been printed.
+_EMITTED_LOGS = set()
+
+# Prints which have been printed.
+_EMITTED_PRINTS = set()
+
 
 
 def set_log_level(log_level: LogLevel):
 def set_log_level(log_level: LogLevel):
     """Set the log level.
     """Set the log level.
@@ -55,25 +73,37 @@ def is_debug() -> bool:
     return _LOG_LEVEL <= LogLevel.DEBUG
     return _LOG_LEVEL <= LogLevel.DEBUG
 
 
 
 
-def print(msg: str, **kwargs):
+def print(msg: str, dedupe: bool = False, **kwargs):
     """Print a message.
     """Print a message.
 
 
     Args:
     Args:
         msg: The message to print.
         msg: The message to print.
+        dedupe: If True, suppress multiple console logs of print message.
         kwargs: Keyword arguments to pass to the print function.
         kwargs: Keyword arguments to pass to the print function.
     """
     """
+    if dedupe:
+        if msg in _EMITTED_PRINTS:
+            return
+        else:
+            _EMITTED_PRINTS.add(msg)
     _console.print(msg, **kwargs)
     _console.print(msg, **kwargs)
 
 
 
 
-def debug(msg: str, **kwargs):
+def debug(msg: str, dedupe: bool = False, **kwargs):
     """Print a debug message.
     """Print a debug message.
 
 
     Args:
     Args:
         msg: The debug message.
         msg: The debug message.
+        dedupe: If True, suppress multiple console logs of debug message.
         kwargs: Keyword arguments to pass to the print function.
         kwargs: Keyword arguments to pass to the print function.
     """
     """
     if is_debug():
     if is_debug():
         msg_ = f"[purple]Debug: {msg}[/purple]"
         msg_ = f"[purple]Debug: {msg}[/purple]"
+        if dedupe:
+            if msg_ in _EMITTED_DEBUG:
+                return
+            else:
+                _EMITTED_DEBUG.add(msg_)
         if progress := kwargs.pop("progress", None):
         if progress := kwargs.pop("progress", None):
             progress.console.print(msg_, **kwargs)
             progress.console.print(msg_, **kwargs)
         else:
         else:
@@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs):
         print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
         print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
 
 
 
 
-def success(msg: str, **kwargs):
+def success(msg: str, dedupe: bool = False, **kwargs):
     """Print a success message.
     """Print a success message.
 
 
     Args:
     Args:
         msg: The success message.
         msg: The success message.
+        dedupe: If True, suppress multiple console logs of success message.
         kwargs: Keyword arguments to pass to the print function.
         kwargs: Keyword arguments to pass to the print function.
     """
     """
     if _LOG_LEVEL <= LogLevel.INFO:
     if _LOG_LEVEL <= LogLevel.INFO:
+        if dedupe:
+            if msg in _EMITTED_SUCCESS:
+                return
+            else:
+                _EMITTED_SUCCESS.add(msg)
         print(f"[green]Success: {msg}[/green]", **kwargs)
         print(f"[green]Success: {msg}[/green]", **kwargs)
 
 
 
 
-def log(msg: str, **kwargs):
+def log(msg: str, dedupe: bool = False, **kwargs):
     """Takes a string and logs it to the console.
     """Takes a string and logs it to the console.
 
 
     Args:
     Args:
         msg: The message to log.
         msg: The message to log.
+        dedupe: If True, suppress multiple console logs of log message.
         kwargs: Keyword arguments to pass to the print function.
         kwargs: Keyword arguments to pass to the print function.
     """
     """
     if _LOG_LEVEL <= LogLevel.INFO:
     if _LOG_LEVEL <= LogLevel.INFO:
+        if dedupe:
+            if msg in _EMITTED_LOGS:
+                return
+            else:
+                _EMITTED_LOGS.add(msg)
         _console.log(msg, **kwargs)
         _console.log(msg, **kwargs)
 
 
 
 
@@ -129,14 +171,20 @@ def rule(title: str, **kwargs):
     _console.rule(title, **kwargs)
     _console.rule(title, **kwargs)
 
 
 
 
-def warn(msg: str, **kwargs):
+def warn(msg: str, dedupe: bool = False, **kwargs):
     """Print a warning message.
     """Print a warning message.
 
 
     Args:
     Args:
         msg: The warning message.
         msg: The warning message.
+        dedupe: If True, suppress multiple console logs of warning message.
         kwargs: Keyword arguments to pass to the print function.
         kwargs: Keyword arguments to pass to the print function.
     """
     """
     if _LOG_LEVEL <= LogLevel.WARNING:
     if _LOG_LEVEL <= LogLevel.WARNING:
+        if dedupe:
+            if msg in _EMIITED_WARNINGS:
+                return
+            else:
+                _EMIITED_WARNINGS.add(msg)
         print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
         print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
 
 
 
 
@@ -169,14 +217,20 @@ def deprecate(
             _EMITTED_DEPRECATION_WARNINGS.add(feature_name)
             _EMITTED_DEPRECATION_WARNINGS.add(feature_name)
 
 
 
 
-def error(msg: str, **kwargs):
+def error(msg: str, dedupe: bool = False, **kwargs):
     """Print an error message.
     """Print an error message.
 
 
     Args:
     Args:
         msg: The error message.
         msg: The error message.
+        dedupe: If True, suppress multiple console logs of error message.
         kwargs: Keyword arguments to pass to the print function.
         kwargs: Keyword arguments to pass to the print function.
     """
     """
     if _LOG_LEVEL <= LogLevel.ERROR:
     if _LOG_LEVEL <= LogLevel.ERROR:
+        if dedupe:
+            if msg in _EMITTED_ERRORS:
+                return
+            else:
+                _EMITTED_ERRORS.add(msg)
         print(f"[red]{msg}[/red]", **kwargs)
         print(f"[red]{msg}[/red]", **kwargs)
 
 
 
 

+ 4 - 0
reflex/utils/exceptions.py

@@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
         " Please install it through your system package manager."
         " Please install it through your system package manager."
         + (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
         + (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
     )
     )
+
+
+class InvalidLockWarningThresholdError(ReflexError):
+    """Raised when an invalid lock warning threshold is provided."""

+ 110 - 4
tests/units/test_state.py

@@ -56,6 +56,7 @@ from reflex.state import (
 from reflex.testing import chdir
 from reflex.testing import chdir
 from reflex.utils import format, prerequisites, types
 from reflex.utils import format, prerequisites, types
 from reflex.utils.exceptions import (
 from reflex.utils.exceptions import (
+    InvalidLockWarningThresholdError,
     ReflexRuntimeError,
     ReflexRuntimeError,
     SetUndefinedStateVarError,
     SetUndefinedStateVarError,
     StateSerializationError,
     StateSerializationError,
@@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState
 from .states import GenState
 from .states import GenState
 
 
 CI = bool(os.environ.get("CI", False))
 CI = bool(os.environ.get("CI", False))
-LOCK_EXPIRATION = 2000 if CI else 300
+LOCK_EXPIRATION = 2500 if CI else 300
+LOCK_WARNING_THRESHOLD = 1000 if CI else 100
+LOCK_WARN_SLEEP = 1.5 if CI else 0.15
 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
 
 
 
 
@@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire(
         substate_token_redis: A token + substate name for looking up in state manager.
         substate_token_redis: A token + substate name for looking up in state manager.
     """
     """
     state_manager_redis.lock_expiration = LOCK_EXPIRATION
     state_manager_redis.lock_expiration = LOCK_EXPIRATION
+    state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
 
 
     async with state_manager_redis.modify_state(substate_token_redis):
     async with state_manager_redis.modify_state(substate_token_redis):
         await asyncio.sleep(0.01)
         await asyncio.sleep(0.01)
@@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
     unexp_num1 = 666
     unexp_num1 = 666
 
 
     state_manager_redis.lock_expiration = LOCK_EXPIRATION
     state_manager_redis.lock_expiration = LOCK_EXPIRATION
+    state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
 
 
     order = []
     order = []
 
 
@@ -1840,6 +1845,39 @@ async def test_state_manager_lock_expire_contend(
     assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
     assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
 
 
 
 
+@pytest.mark.asyncio
+async def test_state_manager_lock_warning_threshold_contend(
+    state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
+):
+    """Test that the state manager triggers a warning when lock contention exceeds the warning threshold.
+
+    Args:
+        state_manager_redis: A state manager instance.
+        token: A token.
+        substate_token_redis: A token + substate name for looking up in state manager.
+        mocker: Pytest mocker object.
+    """
+    console_warn = mocker.patch("reflex.utils.console.warn")
+
+    state_manager_redis.lock_expiration = LOCK_EXPIRATION
+    state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
+
+    order = []
+
+    async def _coro_blocker():
+        async with state_manager_redis.modify_state(substate_token_redis):
+            order.append("blocker")
+            await asyncio.sleep(LOCK_WARN_SLEEP)
+
+    tasks = [
+        asyncio.create_task(_coro_blocker()),
+    ]
+
+    await tasks[0]
+    console_warn.assert_called()
+    assert console_warn.call_count == 7
+
+
 class CopyingAsyncMock(AsyncMock):
 class CopyingAsyncMock(AsyncMock):
     """An AsyncMock, but deepcopy the args and kwargs first."""
     """An AsyncMock, but deepcopy the args and kwargs first."""
 
 
@@ -3253,12 +3291,42 @@ async def test_setvar_async_setter():
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "expiration_kwargs, expected_values",
     "expiration_kwargs, expected_values",
     [
     [
-        ({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
+        (
+            {"redis_lock_expiration": 20000},
+            (
+                20000,
+                constants.Expiration.TOKEN,
+                constants.Expiration.LOCK_WARNING_THRESHOLD,
+            ),
+        ),
         (
         (
             {"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
             {"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
-            (50000, 5600),
+            (50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD),
+        ),
+        (
+            {"redis_token_expiration": 7600},
+            (
+                constants.Expiration.LOCK,
+                7600,
+                constants.Expiration.LOCK_WARNING_THRESHOLD,
+            ),
+        ),
+        (
+            {"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500},
+            (50000, constants.Expiration.TOKEN, 1500),
+        ),
+        (
+            {"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000},
+            (constants.Expiration.LOCK, 5600, 3000),
+        ),
+        (
+            {
+                "redis_lock_expiration": 50000,
+                "redis_token_expiration": 5600,
+                "redis_lock_warning_threshold": 2000,
+            },
+            (50000, 5600, 2000),
         ),
         ),
-        ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
     ],
     ],
 )
 )
 def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
 def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
@@ -3288,6 +3356,44 @@ config = rx.Config(
         state_manager = StateManager.create(state=State)
         state_manager = StateManager.create(state=State)
         assert state_manager.lock_expiration == expected_values[0]  # type: ignore
         assert state_manager.lock_expiration == expected_values[0]  # type: ignore
         assert state_manager.token_expiration == expected_values[1]  # type: ignore
         assert state_manager.token_expiration == expected_values[1]  # type: ignore
+        assert state_manager.lock_warning_threshold == expected_values[2]  # type: ignore
+
+
+@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
+@pytest.mark.parametrize(
+    "redis_lock_expiration, redis_lock_warning_threshold",
+    [
+        (10000, 10000),
+        (20000, 30000),
+    ],
+)
+def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold(
+    tmp_path, redis_lock_expiration, redis_lock_warning_threshold
+):
+    proj_root = tmp_path / "project1"
+    proj_root.mkdir()
+
+    config_string = f"""
+import reflex as rx
+config = rx.Config(
+    app_name="project1",
+    redis_url="redis://localhost:6379",
+    state_manager_mode="redis",
+    redis_lock_expiration = {redis_lock_expiration},
+    redis_lock_warning_threshold = {redis_lock_warning_threshold},
+)
+    """
+
+    (proj_root / "rxconfig.py").write_text(dedent(config_string))
+
+    with chdir(proj_root):
+        # reload config for each parameter to avoid stale values
+        reflex.config.get_config(reload=True)
+        from reflex.state import State, StateManager
+
+        with pytest.raises(InvalidLockWarningThresholdError):
+            StateManager.create(state=State)
+        del sys.modules[constants.Config.MODULE]
 
 
 
 
 class MixinState(State, mixin=True):
 class MixinState(State, mixin=True):