Kaynağa Gözat

/health endpoint for K8 Liveness and Readiness probes (#3855)

* Added API Endpoint

* Added API Endpoint

* Added Unit Tests

* Added Unit Tests

* main

* Apply suggestions from Code Review

* Fix Ruff Formatting

* Update Socket Events

* Async Functions
Samarth Bhadane 9 ay önce
ebeveyn
işleme
59047303c9

+ 35 - 2
reflex/app.py

@@ -33,7 +33,7 @@ from typing import (
 
 from fastapi import FastAPI, HTTPException, Request, UploadFile
 from fastapi.middleware import cors
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
@@ -65,7 +65,7 @@ from reflex.components.core.upload import Upload, get_upload_dir
 from reflex.components.radix import themes
 from reflex.config import get_config
 from reflex.event import Event, EventHandler, EventSpec, window_alert
-from reflex.model import Model
+from reflex.model import Model, get_db_status
 from reflex.page import (
     DECORATED_PAGES,
 )
@@ -377,6 +377,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         """Add default api endpoints (ping)."""
         # To test the server.
         self.api.get(str(constants.Endpoint.PING))(ping)
+        self.api.get(str(constants.Endpoint.HEALTH))(health)
 
     def _add_optional_endpoints(self):
         """Add optional api endpoints (_upload)."""
@@ -1319,6 +1320,38 @@ async def ping() -> str:
     return "pong"
 
 
+async def health() -> JSONResponse:
+    """Health check endpoint to assess the status of the database and Redis services.
+
+    Returns:
+        JSONResponse: A JSON object with the health status:
+            - "status" (bool): Overall health, True if all checks pass.
+            - "db" (bool or str): Database status - True, False, or "NA".
+            - "redis" (bool or str): Redis status - True, False, or "NA".
+    """
+    health_status = {"status": True}
+    status_code = 200
+
+    db_status, redis_status = await asyncio.gather(
+        get_db_status(), prerequisites.get_redis_status()
+    )
+
+    health_status["db"] = db_status
+
+    if redis_status is None:
+        health_status["redis"] = False
+    else:
+        health_status["redis"] = redis_status
+
+    if not health_status["db"] or (
+        not health_status["redis"] and redis_status is not None
+    ):
+        health_status["status"] = False
+        status_code = 503
+
+    return JSONResponse(content=health_status, status_code=status_code)
+
+
 def upload(app: App):
     """Upload a file.
 

+ 1 - 0
reflex/constants/event.py

@@ -11,6 +11,7 @@ class Endpoint(Enum):
     EVENT = "_event"
     UPLOAD = "_upload"
     AUTH_CODESPACE = "auth-codespace"
+    HEALTH = "_health"
 
     def __str__(self) -> str:
         """Get the string representation of the endpoint.

+ 22 - 0
reflex/model.py

@@ -15,6 +15,7 @@ import alembic.runtime.environment
 import alembic.script
 import alembic.util
 import sqlalchemy
+import sqlalchemy.exc
 import sqlalchemy.orm
 
 from reflex import constants
@@ -51,6 +52,27 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
     return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
 
 
+async def get_db_status() -> bool:
+    """Checks the status of the database connection.
+
+    Attempts to connect to the database and execute a simple query to verify connectivity.
+
+    Returns:
+        bool: The status of the database connection:
+            - True: The database is accessible.
+            - False: The database is not accessible.
+    """
+    status = True
+    try:
+        engine = get_engine()
+        with engine.connect() as connection:
+            connection.execute(sqlalchemy.text("SELECT 1"))
+    except sqlalchemy.exc.OperationalError:
+        status = False
+
+    return status
+
+
 SQLModelOrSqlAlchemy = Union[
     Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
 ]

+ 25 - 0
reflex/utils/prerequisites.py

@@ -28,6 +28,7 @@ import typer
 from alembic.util.exc import CommandError
 from packaging import version
 from redis import Redis as RedisSync
+from redis import exceptions
 from redis.asyncio import Redis
 
 from reflex import constants, model
@@ -344,6 +345,30 @@ def parse_redis_url() -> str | dict | None:
     return dict(host=redis_url, port=int(redis_port), db=0)
 
 
+async def get_redis_status() -> bool | None:
+    """Checks the status of the Redis connection.
+
+    Attempts to connect to Redis and send a ping command to verify connectivity.
+
+    Returns:
+        bool or None: The status of the Redis connection:
+            - True: Redis is accessible and responding.
+            - False: Redis is not accessible due to a connection error.
+            - None: Redis not used i.e redis_url is not set in rxconfig.
+    """
+    try:
+        status = True
+        redis_client = get_redis_sync()
+        if redis_client is not None:
+            redis_client.ping()
+        else:
+            status = None
+    except exceptions.RedisError:
+        status = False
+
+    return status
+
+
 def validate_app_name(app_name: str | None = None) -> str:
     """Validate the app name.
 

+ 106 - 0
tests/test_health_endpoint.py

@@ -0,0 +1,106 @@
+import json
+from unittest.mock import MagicMock, Mock
+
+import pytest
+import sqlalchemy
+from redis.exceptions import RedisError
+
+from reflex.app import health
+from reflex.model import get_db_status
+from reflex.utils.prerequisites import get_redis_status
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "mock_redis_client, expected_status",
+    [
+        # Case 1: Redis client is available and responds to ping
+        (Mock(ping=lambda: None), True),
+        # Case 2: Redis client raises RedisError
+        (Mock(ping=lambda: (_ for _ in ()).throw(RedisError)), False),
+        # Case 3: Redis client is not used
+        (None, None),
+    ],
+)
+async def test_get_redis_status(mock_redis_client, expected_status, mocker):
+    # Mock the `get_redis_sync` function to return the mock Redis client
+    mock_get_redis_sync = mocker.patch(
+        "reflex.utils.prerequisites.get_redis_sync", return_value=mock_redis_client
+    )
+
+    # Call the function
+    status = await get_redis_status()
+
+    # Verify the result
+    assert status == expected_status
+    mock_get_redis_sync.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "mock_engine, execute_side_effect, expected_status",
+    [
+        # Case 1: Database is accessible
+        (MagicMock(), None, True),
+        # Case 2: Database connection error (OperationalError)
+        (
+            MagicMock(),
+            sqlalchemy.exc.OperationalError("error", "error", "error"),
+            False,
+        ),
+    ],
+)
+async def test_get_db_status(mock_engine, execute_side_effect, expected_status, mocker):
+    # Mock get_engine to return the mock_engine
+    mock_get_engine = mocker.patch("reflex.model.get_engine", return_value=mock_engine)
+
+    # Mock the connection and its execute method
+    if mock_engine:
+        mock_connection = mock_engine.connect.return_value.__enter__.return_value
+        if execute_side_effect:
+            # Simulate execute method raising an exception
+            mock_connection.execute.side_effect = execute_side_effect
+        else:
+            # Simulate successful execute call
+            mock_connection.execute.return_value = None
+
+    # Call the function
+    status = await get_db_status()
+
+    # Verify the result
+    assert status == expected_status
+    mock_get_engine.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "db_status, redis_status, expected_status, expected_code",
+    [
+        # Case 1: Both services are connected
+        (True, True, {"status": True, "db": True, "redis": True}, 200),
+        # Case 2: Database not connected, Redis connected
+        (False, True, {"status": False, "db": False, "redis": True}, 503),
+        # Case 3: Database connected, Redis not connected
+        (True, False, {"status": False, "db": True, "redis": False}, 503),
+        # Case 4: Both services not connected
+        (False, False, {"status": False, "db": False, "redis": False}, 503),
+        # Case 5: Database Connected, Redis not used
+        (True, None, {"status": True, "db": True, "redis": False}, 200),
+    ],
+)
+async def test_health(db_status, redis_status, expected_status, expected_code, mocker):
+    # Mock get_db_status and get_redis_status
+    mocker.patch("reflex.app.get_db_status", return_value=db_status)
+    mocker.patch(
+        "reflex.utils.prerequisites.get_redis_status", return_value=redis_status
+    )
+
+    # Call the async health function
+    response = await health()
+
+    print(json.loads(response.body))
+    print(expected_status)
+
+    # Verify the response content and status code
+    assert response.status_code == expected_code
+    assert json.loads(response.body) == expected_status