Просмотр исходного кода

optimize rx.color to not use validate literal parameters (#5244)

* optimize rx.color to not use validate literal parameters

* fstring issues
Khaleel Al-Adhami 3 недель назад
Родитель
Сommit
c14333515a

+ 35 - 4
reflex/components/core/colors.py

@@ -1,11 +1,22 @@
 """The colors used in Reflex are a wrapper around https://www.radix-ui.com/colors."""
 
-from reflex.constants.colors import Color, ColorType, ShadeType
-from reflex.utils.types import validate_parameter_literals
+from reflex.constants.base import REFLEX_VAR_OPENING_TAG
+from reflex.constants.colors import (
+    COLORS,
+    MAX_SHADE_VALUE,
+    MIN_SHADE_VALUE,
+    Color,
+    ColorType,
+    ShadeType,
+)
+from reflex.vars.base import Var
 
 
-@validate_parameter_literals
-def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Color:
+def color(
+    color: ColorType | Var[str],
+    shade: ShadeType | Var[int] = 7,
+    alpha: bool | Var[bool] = False,
+) -> Color:
     """Create a color object.
 
     Args:
@@ -15,5 +26,25 @@ def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Color:
 
     Returns:
         The color object.
+
+    Raises:
+        ValueError: If the color, shade, or alpha are not valid.
     """
+    if isinstance(color, str):
+        if color not in COLORS and REFLEX_VAR_OPENING_TAG not in color:
+            raise ValueError(f"Color must be one of {COLORS}, received {color}")
+    elif not isinstance(color, Var):
+        raise ValueError("Color must be a string or a Var")
+
+    if isinstance(shade, int):
+        if shade < MIN_SHADE_VALUE or shade > MAX_SHADE_VALUE:
+            raise ValueError(
+                f"Shade must be between {MIN_SHADE_VALUE} and {MAX_SHADE_VALUE}"
+            )
+    elif not isinstance(shade, Var):
+        raise ValueError("Shade must be an integer or a Var")
+
+    if not isinstance(alpha, (bool, Var)):
+        raise ValueError("Alpha must be a boolean or a Var")
+
     return Color(color, shade, alpha)

+ 23 - 6
reflex/constants/colors.py

@@ -1,7 +1,12 @@
 """The colors used in Reflex are a wrapper around https://www.radix-ui.com/colors."""
 
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Literal
+from typing import TYPE_CHECKING, Literal, get_args
+
+if TYPE_CHECKING:
+    from reflex.vars import Var
 
 ColorType = Literal[
     "gray",
@@ -40,10 +45,16 @@ ColorType = Literal[
     "white",
 ]
 
+COLORS = frozenset(get_args(ColorType))
+
 ShadeType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+MIN_SHADE_VALUE = 1
+MAX_SHADE_VALUE = 12
 
 
-def format_color(color: ColorType, shade: ShadeType, alpha: bool) -> str:
+def format_color(
+    color: ColorType | Var[str], shade: ShadeType | Var[int], alpha: bool | Var[bool]
+) -> str:
     """Format a color as a CSS color string.
 
     Args:
@@ -54,7 +65,13 @@ def format_color(color: ColorType, shade: ShadeType, alpha: bool) -> str:
     Returns:
         The formatted color.
     """
-    return f"var(--{color}-{'a' if alpha else ''}{shade})"
+    if isinstance(alpha, bool):
+        return f"var(--{color}-{'a' if alpha else ''}{shade})"
+
+    from reflex.components.core import cond
+
+    alpha_var = cond(alpha, "a", "")
+    return f"var(--{color}-{alpha_var}{shade})"
 
 
 @dataclass
@@ -62,13 +79,13 @@ class Color:
     """A color in the Reflex color palette."""
 
     # The color palette to use
-    color: ColorType
+    color: ColorType | Var[str]
 
     # The shade of the color to use
-    shade: ShadeType = 7
+    shade: ShadeType | Var[int] = 7
 
     # Whether to use the alpha variant of the color
-    alpha: bool = False
+    alpha: bool | Var[bool] = False
 
     def __format__(self, format_spec: str) -> str:
         """Format the color as a CSS color string.

+ 10 - 3
reflex/utils/types.py

@@ -909,12 +909,19 @@ def validate_parameter_literals(func: Callable):
     Returns:
         The wrapper function.
     """
+    console.deprecate(
+        "validate_parameter_literals",
+        reason="Use manual validation instead.",
+        deprecation_version="0.7.11",
+        removal_version="0.8.0",
+        dedupe=True,
+    )
+
+    func_params = list(inspect.signature(func).parameters.items())
+    annotations = {param[0]: param[1].annotation for param in func_params}
 
     @wraps(func)
     def wrapper(*args, **kwargs):
-        func_params = list(inspect.signature(func).parameters.items())
-        annotations = {param[0]: param[1].annotation for param in func_params}
-
         # validate args
         for param, arg in zip(annotations, args, strict=False):
             if annotations[param] is inspect.Parameter.empty:

+ 19 - 12
tests/units/components/core/test_colors.py

@@ -9,10 +9,10 @@ from reflex.vars.base import LiteralVar
 class ColorState(rx.State):
     """Test color state."""
 
-    color: str = "mint"
-    color_part: str = "tom"
-    shade: int = 4
-    alpha: bool = False
+    color: rx.Field[str] = rx.field("mint")
+    color_part: rx.Field[str] = rx.field("tom")
+    shade: rx.Field[int] = rx.field(4)
+    alpha: rx.Field[bool] = rx.field(False)
 
 
 color_state_name = ColorState.get_full_name().replace(".", "__")
@@ -22,6 +22,12 @@ def create_color_var(color):
     return LiteralVar.create(color)
 
 
+color_with_fstring = rx.color(
+    f"{ColorState.color}",  # pyright: ignore [reportArgumentType]
+    ColorState.shade,
+)
+
+
 @pytest.mark.parametrize(
     "color, expected, expected_type",
     [
@@ -41,26 +47,27 @@ def create_color_var(color):
             Color,
         ),
         (
-            create_color_var(rx.color(f"{ColorState.color}", f"{ColorState.shade}")),
-            f'("var(--"+{color_state_name!s}.color+"-"+{color_state_name!s}.shade+")")',
+            create_color_var(color_with_fstring),
+            f'("var(--"+{color_state_name!s}.color+"-"+(((__to_string) => __to_string.toString())({color_state_name!s}.shade))+")")',
             Color,
         ),
         (
             create_color_var(
-                rx.color(f"{ColorState.color_part}ato", f"{ColorState.shade}")
+                rx.color(
+                    f"{ColorState.color_part}ato",  # pyright: ignore [reportArgumentType]
+                    ColorState.shade,
+                )
             ),
-            f'("var(--"+({color_state_name!s}.color_part+"ato")+"-"+{color_state_name!s}.shade+")")',
+            f'("var(--"+({color_state_name!s}.color_part+"ato")+"-"+(((__to_string) => __to_string.toString())({color_state_name!s}.shade))+")")',
             Color,
         ),
         (
-            create_color_var(f"{rx.color(ColorState.color, f'{ColorState.shade}')}"),
+            create_color_var(f"{rx.color(ColorState.color, ColorState.shade)}"),
             f'("var(--"+{color_state_name!s}.color+"-"+{color_state_name!s}.shade+")")',
             str,
         ),
         (
-            create_color_var(
-                f"{rx.color(f'{ColorState.color}', f'{ColorState.shade}')}"
-            ),
+            create_color_var(f"{color_with_fstring}"),
             f'("var(--"+{color_state_name!s}.color+"-"+{color_state_name!s}.shade+")")',
             str,
         ),