Explorar o código

port enum env var support from #4248 (#4251)

* port enum env var support from #4248

* add some tests for interpret env var functions
benedikt-bartscher hai 6 meses
pai
achega
2e100e38d9
Modificáronse 2 ficheiros con 48 adicións e 4 borrados
  1. 26 0
      reflex/config.py
  2. 22 4
      tests/units/test_config.py

+ 26 - 0
reflex/config.py

@@ -3,7 +3,9 @@
 from __future__ import annotations
 
 import dataclasses
+import enum
 import importlib
+import inspect
 import os
 import sys
 import urllib.parse
@@ -221,6 +223,28 @@ def interpret_path_env(value: str, field_name: str) -> Path:
     return path
 
 
+def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> Any:
+    """Interpret an enum environment variable value.
+
+    Args:
+        value: The environment variable value.
+        field_type: The field type.
+        field_name: The field name.
+
+    Returns:
+        The interpreted value.
+
+    Raises:
+        EnvironmentVarValueError: If the value is invalid.
+    """
+    try:
+        return field_type(value)
+    except ValueError as ve:
+        raise EnvironmentVarValueError(
+            f"Invalid enum value: {value} for {field_name}"
+        ) from ve
+
+
 def interpret_env_var_value(
     value: str, field_type: GenericType, field_name: str
 ) -> Any:
@@ -252,6 +276,8 @@ def interpret_env_var_value(
         return interpret_int_env(value, field_name)
     elif field_type is Path:
         return interpret_path_env(value, field_name)
+    elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
+        return interpret_enum_env(value, field_type, field_name)
 
     else:
         raise ValueError(

+ 22 - 4
tests/units/test_config.py

@@ -7,8 +7,13 @@ import pytest
 
 import reflex as rx
 import reflex.config
-from reflex.config import environment
-from reflex.constants import Endpoint
+from reflex.config import (
+    environment,
+    interpret_boolean_env,
+    interpret_enum_env,
+    interpret_int_env,
+)
+from reflex.constants import Endpoint, Env
 
 
 def test_requires_app_name():
@@ -208,11 +213,11 @@ def test_replace_defaults(
         assert getattr(c, key) == value
 
 
-def reflex_dir_constant():
+def reflex_dir_constant() -> Path:
     return environment.REFLEX_DIR
 
 
-def test_reflex_dir_env_var(monkeypatch, tmp_path):
+def test_reflex_dir_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
     """Test that the REFLEX_DIR environment variable is used to set the Reflex.DIR constant.
 
     Args:
@@ -224,3 +229,16 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path):
     mp_ctx = multiprocessing.get_context(method="spawn")
     with mp_ctx.Pool(processes=1) as pool:
         assert pool.apply(reflex_dir_constant) == tmp_path
+
+
+def test_interpret_enum_env() -> None:
+    assert interpret_enum_env(Env.PROD.value, Env, "REFLEX_ENV") == Env.PROD
+
+
+def test_interpret_int_env() -> None:
+    assert interpret_int_env("3001", "FRONTEND_PORT") == 3001
+
+
+@pytest.mark.parametrize("value, expected", [("true", True), ("false", False)])
+def test_interpret_bool_env(value: str, expected: bool) -> None:
+    assert interpret_boolean_env(value, "TELEMETRY_ENABLED") == expected