Browse Source

use sync redis client to sanity check (#2679)

Martin Xu 1 year ago
parent
commit
a3be76fb75
3 changed files with 44 additions and 5 deletions
  1. 1 0
      reflex/state.py
  2. 31 4
      reflex/utils/prerequisites.py
  3. 12 1
      reflex/utils/processes.py

+ 1 - 0
reflex/state.py

@@ -1,4 +1,5 @@
 """Define the reflex state specification."""
+
 from __future__ import annotations
 
 import asyncio

+ 31 - 4
reflex/utils/prerequisites.py

@@ -24,6 +24,7 @@ import pkg_resources
 import typer
 from alembic.util.exc import CommandError
 from packaging import version
+from redis import Redis as RedisSync
 from redis.asyncio import Redis
 
 import reflex
@@ -189,16 +190,42 @@ def get_compiled_app(reload: bool = False) -> ModuleType:
 
 
 def get_redis() -> Redis | None:
-    """Get the redis client.
+    """Get the asynchronous redis client.
 
     Returns:
-        The redis client.
+        The asynchronous redis client.
+    """
+    if isinstance((redis_url_or_options := parse_redis_url()), str):
+        return Redis.from_url(redis_url_or_options)
+    elif isinstance(redis_url_or_options, dict):
+        return Redis(**redis_url_or_options)
+    return None
+
+
+def get_redis_sync() -> RedisSync | None:
+    """Get the synchronous redis client.
+
+    Returns:
+        The synchronous redis client.
+    """
+    if isinstance((redis_url_or_options := parse_redis_url()), str):
+        return RedisSync.from_url(redis_url_or_options)
+    elif isinstance(redis_url_or_options, dict):
+        return RedisSync(**redis_url_or_options)
+    return None
+
+
+def parse_redis_url() -> str | dict | None:
+    """Parse the REDIS_URL in config if applicable.
+
+    Returns:
+        If redis-py syntax, return the URL as it is. Otherwise, return the host/port/db as a dict.
     """
     config = get_config()
     if not config.redis_url:
         return None
     if config.redis_url.startswith(("redis://", "rediss://", "unix://")):
-        return Redis.from_url(config.redis_url)
+        return config.redis_url
     console.deprecate(
         feature_name="host[:port] style redis urls",
         reason="redis-py url syntax is now being used",
@@ -209,7 +236,7 @@ def get_redis() -> Redis | None:
     if not has_port:
         redis_port = 6379
     console.info(f"Using redis at {config.redis_url}")
-    return Redis(host=redis_url, port=int(redis_port), db=0)
+    return dict(host=redis_url, port=int(redis_port), db=0)
 
 
 def get_production_backend_url() -> str:

+ 12 - 1
reflex/utils/processes.py

@@ -12,6 +12,7 @@ from typing import Callable, Generator, List, Optional, Tuple, Union
 
 import psutil
 import typer
+from redis.exceptions import RedisError
 
 from reflex.utils import console, path_ops, prerequisites
 
@@ -28,10 +29,20 @@ def kill(pid):
 def get_num_workers() -> int:
     """Get the number of backend worker processes.
 
+    Raises:
+        Exit: If unable to connect to Redis.
+
     Returns:
         The number of backend worker processes.
     """
-    return 1 if prerequisites.get_redis() is None else (os.cpu_count() or 1) * 2 + 1
+    if (redis_client := prerequisites.get_redis_sync()) is None:
+        return 1
+    try:
+        redis_client.ping()
+    except RedisError as re:
+        console.error(f"Unable to connect to Redis: {re}")
+        raise typer.Exit(1) from re
+    return (os.cpu_count() or 1) * 2 + 1
 
 
 def get_process_on_port(port) -> Optional[psutil.Process]: