Ver código fonte

Enable automatic retry on redis errors

ExponentialBackoff 3x retry for BusyLoadingError, ConnectionError, and TimeoutError
Masen Furer 4 meses atrás
pai
commit
c902e6dd45
1 arquivos alterados com 23 adições e 12 exclusões
  1. 23 12
      reflex/utils/prerequisites.py

+ 23 - 12
reflex/utils/prerequisites.py

@@ -21,15 +21,17 @@ import zipfile
 from datetime import datetime
 from datetime import datetime
 from pathlib import Path
 from pathlib import Path
 from types import ModuleType
 from types import ModuleType
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional
 
 
 import httpx
 import httpx
 import typer
 import typer
 from alembic.util.exc import CommandError
 from alembic.util.exc import CommandError
 from packaging import version
 from packaging import version
 from redis import Redis as RedisSync
 from redis import Redis as RedisSync
-from redis import exceptions
 from redis.asyncio import Redis
 from redis.asyncio import Redis
+from redis.backoff import ExponentialBackoff
+from redis.exceptions import BusyLoadingError, ConnectionError, RedisError, TimeoutError
+from redis.retry import Retry
 
 
 from reflex import constants, model
 from reflex import constants, model
 from reflex.compiler import templates
 from reflex.compiler import templates
@@ -327,16 +329,24 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
     return app_module
     return app_module
 
 
 
 
+def _get_common_redis_kwargs() -> dict[str, Any]:
+    return {
+        "retry": Retry(ExponentialBackoff(), 3),
+        "retry_on_error": [BusyLoadingError, ConnectionError, TimeoutError],
+    }
+
+
 def get_redis() -> Redis | None:
 def get_redis() -> Redis | None:
     """Get the asynchronous redis client.
     """Get the asynchronous redis client.
 
 
     Returns:
     Returns:
         The asynchronous redis client.
         The asynchronous redis client.
     """
     """
-    if isinstance((redis_url_or_options := parse_redis_url()), str):
-        return Redis.from_url(redis_url_or_options)
-    elif isinstance(redis_url_or_options, dict):
-        return Redis(**redis_url_or_options)
+    if (redis_url := parse_redis_url()) is not None:
+        return Redis.from_url(
+            redis_url,
+            **_get_common_redis_kwargs(),
+        )
     return None
     return None
 
 
 
 
@@ -346,14 +356,15 @@ def get_redis_sync() -> RedisSync | None:
     Returns:
     Returns:
         The synchronous redis client.
         The synchronous redis client.
     """
     """
-    if isinstance((redis_url_or_options := parse_redis_url()), str):
-        return RedisSync.from_url(redis_url_or_options)
-    elif isinstance(redis_url_or_options, dict):
-        return RedisSync(**redis_url_or_options)
+    if (redis_url := parse_redis_url()) is not None:
+        return RedisSync.from_url(
+            redis_url,
+            **_get_common_redis_kwargs(),
+        )
     return None
     return None
 
 
 
 
-def parse_redis_url() -> str | dict | None:
+def parse_redis_url() -> str | None:
     """Parse the REDIS_URL in config if applicable.
     """Parse the REDIS_URL in config if applicable.
 
 
     Returns:
     Returns:
@@ -387,7 +398,7 @@ async def get_redis_status() -> dict[str, bool | None]:
             redis_client.ping()
             redis_client.ping()
         else:
         else:
             status = None
             status = None
-    except exceptions.RedisError:
+    except RedisError:
         status = False
         status = False
 
 
     return {"redis": status}
     return {"redis": status}