Selaa lähdekoodia

Decorator to validate `rx.color` prop fields (#2553)

Elijah Ahianyo 1 vuosi sitten
vanhempi
säilyke
64a90fa6eb
3 muutettua tiedostoa jossa 72 lisäystä ja 17 poistoa
  1. 4 12
      reflex/components/component.py
  2. 3 5
      reflex/components/core/colors.py
  3. 65 0
      reflex/utils/types.py

+ 4 - 12
reflex/components/component.py

@@ -199,7 +199,6 @@ class Component(BaseComponent, ABC):
 
         Raises:
             TypeError: If an invalid prop is passed.
-            ValueError: If a prop value is invalid.
         """
         # Set the id and children initially.
         children = kwargs.get("children", [])
@@ -249,17 +248,10 @@ class Component(BaseComponent, ABC):
                         raise TypeError
 
                     expected_type = fields[key].outer_type_.__args__[0]
-
-                    if (
-                        types.is_literal(expected_type)
-                        and value not in expected_type.__args__
-                    ):
-                        allowed_values = expected_type.__args__
-                        if value not in allowed_values:
-                            raise ValueError(
-                                f"prop value for {key} of the `{type(self).__name__}` component should be one of the following: {','.join(allowed_values)}. Got '{value}' instead"
-                            )
-
+                    # validate literal fields.
+                    types.validate_literal(
+                        key, value, expected_type, type(self).__name__
+                    )
                     # Get the passed type and the var type.
                     passed_type = kwargs[key]._var_type
                     expected_type = (

+ 3 - 5
reflex/components/core/colors.py

@@ -1,13 +1,11 @@
 """The colors used in Reflex are a wrapper around https://www.radix-ui.com/colors."""
 
 from reflex.constants.colors import Color, ColorType, ShadeType
+from reflex.utils.types import validate_parameter_literals
 
 
-def color(
-    color: ColorType,
-    shade: ShadeType = 7,
-    alpha: bool = False,
-) -> Color:
+@validate_parameter_literals
+def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Color:
     """Create a color object.
 
     Args:

+ 65 - 0
reflex/utils/types.py

@@ -3,7 +3,9 @@
 from __future__ import annotations
 
 import contextlib
+import inspect
 import types
+from functools import wraps
 from typing import (
     Any,
     Callable,
@@ -330,6 +332,69 @@ def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
     return type_ in allowed_types
 
 
+def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str):
+    """Check that a value is a valid literal.
+
+    Args:
+        key: The prop name.
+        value: The prop value to validate.
+        expected_type: The expected type(literal type).
+        comp_name: Name of the component.
+
+    Raises:
+        ValueError: When the value is not a valid literal.
+    """
+    from reflex.vars import Var
+
+    if (
+        is_literal(expected_type)
+        and not isinstance(value, Var)  # validating vars is not supported yet.
+        and value not in expected_type.__args__
+    ):
+        allowed_values = expected_type.__args__
+        if value not in allowed_values:
+            value_str = ",".join(
+                [str(v) if not isinstance(v, str) else f"'{v}'" for v in allowed_values]
+            )
+            raise ValueError(
+                f"prop value for {str(key)} of the `{comp_name}` component should be one of the following: {value_str}. Got '{value}' instead"
+            )
+
+
+def validate_parameter_literals(func):
+    """Decorator to check that the arguments passed to a function
+    correspond to the correct function parameter if it (the parameter)
+    is a literal type.
+
+    Args:
+        func: The function to validate.
+
+    Returns:
+        The wrapper function.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        func_params = list(inspect.signature(func).parameters.items())
+        annotations = {param[0]: param[1].annotation for param in func_params}
+
+        # validate args
+        for param, arg in zip(annotations.keys(), args):
+            if annotations[param] is inspect.Parameter.empty:
+                continue
+            validate_literal(param, arg, annotations[param], func.__name__)
+
+        # validate kwargs.
+        for key, value in kwargs.items():
+            annotation = annotations.get(key)
+            if not annotation or annotation is inspect.Parameter.empty:
+                continue
+            validate_literal(key, value, annotation, func.__name__)
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 # Store this here for performance.
 StateBases = get_base_class(StateVar)
 StateIterBases = get_base_class(StateIterVar)