Bläddra i källkod

allow optional props with None default value (#3179)

benedikt-bartscher 1 år sedan
förälder
incheckning
e31b458a69
3 ändrade filer med 201 tillägg och 7 borttagningar
  1. 20 3
      reflex/components/component.py
  2. 155 3
      tests/components/test_component.py
  3. 26 1
      tests/utils/test_types.py

+ 20 - 3
reflex/components/component.py

@@ -303,6 +303,8 @@ class Component(BaseComponent, ABC):
 
             # Check whether the key is a component prop.
             if types._issubclass(field_type, Var):
+                # Used to store the passed types if var type is a union.
+                passed_types = None
                 try:
                     # Try to create a var from the value.
                     kwargs[key] = Var.create(value)
@@ -327,10 +329,25 @@ class Component(BaseComponent, ABC):
                     # If it is not a valid var, check the base types.
                     passed_type = type(value)
                     expected_type = fields[key].outer_type_
-                if not types._issubclass(passed_type, expected_type):
+                if types.is_union(passed_type):
+                    # We need to check all possible types in the union.
+                    passed_types = (
+                        arg for arg in passed_type.__args__ if arg is not type(None)
+                    )
+                if (
+                    # If the passed var is a union, check if all possible types are valid.
+                    passed_types
+                    and not all(
+                        types._issubclass(pt, expected_type) for pt in passed_types
+                    )
+                ) or (
+                    # Else just check if the passed var type is valid.
+                    not passed_types
+                    and not types._issubclass(passed_type, expected_type)
+                ):
                     value_name = value._var_name if isinstance(value, Var) else value
                     raise TypeError(
-                        f"Invalid var passed for prop {type(self).__name__}.{key}, expected type {expected_type}, got value {value_name} of type {passed_type}."
+                        f"Invalid var passed for prop {type(self).__name__}.{key}, expected type {expected_type}, got value {value_name} of type {passed_types or passed_type}."
                     )
 
             # Check if the key is an event trigger.
@@ -1523,7 +1540,7 @@ class CustomComponent(Component):
 
 
 def custom_component(
-    component_fn: Callable[..., Component]
+    component_fn: Callable[..., Component],
 ) -> Callable[..., CustomComponent]:
     """Create a custom component from a function.
 

+ 155 - 3
tests/components/test_component.py

@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Type
+from contextlib import nullcontext
+from typing import Any, Dict, List, Optional, Type, Union
 
 import pytest
 
@@ -20,7 +21,7 @@ from reflex.state import BaseState
 from reflex.style import Style
 from reflex.utils import imports
 from reflex.utils.imports import ImportVar
-from reflex.vars import Var, VarData
+from reflex.vars import BaseVar, Var, VarData
 
 
 @pytest.fixture
@@ -52,6 +53,9 @@ def component1() -> Type[Component]:
         # A test number prop.
         number: Var[int]
 
+        # A test string/number prop.
+        text_or_number: Var[Union[int, str]]
+
         def _get_imports(self) -> imports.ImportDict:
             return {"react": [ImportVar(tag="Component")]}
 
@@ -253,6 +257,154 @@ def test_create_component(component1):
     assert c.style == {"color": "white", "textAlign": "center"}
 
 
+@pytest.mark.parametrize(
+    "prop_name,var,expected",
+    [
+        pytest.param(
+            "text",
+            Var.create("hello"),
+            None,
+            id="text",
+        ),
+        pytest.param(
+            "text",
+            BaseVar(_var_name="hello", _var_type=Optional[str]),
+            None,
+            id="text-optional",
+        ),
+        pytest.param(
+            "text",
+            BaseVar(_var_name="hello", _var_type=Union[str, None]),
+            None,
+            id="text-union-str-none",
+        ),
+        pytest.param(
+            "text",
+            BaseVar(_var_name="hello", _var_type=Union[None, str]),
+            None,
+            id="text-union-none-str",
+        ),
+        pytest.param(
+            "text",
+            Var.create(1),
+            TypeError,
+            id="text-int",
+        ),
+        pytest.param(
+            "number",
+            Var.create(1),
+            None,
+            id="number",
+        ),
+        pytest.param(
+            "number",
+            BaseVar(_var_name="1", _var_type=Optional[int]),
+            None,
+            id="number-optional",
+        ),
+        pytest.param(
+            "number",
+            BaseVar(_var_name="1", _var_type=Union[int, None]),
+            None,
+            id="number-union-int-none",
+        ),
+        pytest.param(
+            "number",
+            BaseVar(_var_name="1", _var_type=Union[None, int]),
+            None,
+            id="number-union-none-int",
+        ),
+        pytest.param(
+            "number",
+            Var.create("1"),
+            TypeError,
+            id="number-str",
+        ),
+        pytest.param(
+            "text_or_number",
+            Var.create("hello"),
+            None,
+            id="text_or_number-str",
+        ),
+        pytest.param(
+            "text_or_number",
+            Var.create(1),
+            None,
+            id="text_or_number-int",
+        ),
+        pytest.param(
+            "text_or_number",
+            BaseVar(_var_name="hello", _var_type=Optional[str]),
+            None,
+            id="text_or_number-optional-str",
+        ),
+        pytest.param(
+            "text_or_number",
+            BaseVar(_var_name="hello", _var_type=Union[str, None]),
+            None,
+            id="text_or_number-union-str-none",
+        ),
+        pytest.param(
+            "text_or_number",
+            BaseVar(_var_name="hello", _var_type=Union[None, str]),
+            None,
+            id="text_or_number-union-none-str",
+        ),
+        pytest.param(
+            "text_or_number",
+            BaseVar(_var_name="1", _var_type=Optional[int]),
+            None,
+            id="text_or_number-optional-int",
+        ),
+        pytest.param(
+            "text_or_number",
+            BaseVar(_var_name="1", _var_type=Union[int, None]),
+            None,
+            id="text_or_number-union-int-none",
+        ),
+        pytest.param(
+            "text_or_number",
+            BaseVar(_var_name="1", _var_type=Union[None, int]),
+            None,
+            id="text_or_number-union-none-int",
+        ),
+        pytest.param(
+            "text_or_number",
+            Var.create(1.0),
+            TypeError,
+            id="text_or_number-float",
+        ),
+        pytest.param(
+            "text_or_number",
+            BaseVar(_var_name="hello", _var_type=Optional[Union[str, int]]),
+            None,
+            id="text_or_number-optional-union-str-int",
+        ),
+    ],
+)
+def test_create_component_prop_validation(
+    component1: Type[Component],
+    prop_name: str,
+    var: Union[Var, str, int],
+    expected: Type[Exception],
+):
+    """Test that component props are validated correctly.
+
+    Args:
+        component1: A test component.
+        prop_name: The name of the prop.
+        var: The value of the prop.
+        expected: The expected exception.
+    """
+    ctx = pytest.raises(expected) if expected else nullcontext()
+    kwargs = {prop_name: var}
+    with ctx:
+        c = component1.create(**kwargs)
+        assert isinstance(c, component1)
+        assert c.children == []
+        assert c.style == {}
+
+
 def test_add_style(component1, component2):
     """Test adding a style to a component.
 
@@ -338,7 +490,7 @@ def test_get_props(component1, component2):
         component1: A test component.
         component2: A test component.
     """
-    assert component1.get_props() == {"text", "number"}
+    assert component1.get_props() == {"text", "number", "text_or_number"}
     assert component2.get_props() == {"arr"}
 
 

+ 26 - 1
tests/utils/test_types.py

@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Any, List, Literal, Tuple, Union
 
 import pytest
 
@@ -20,3 +20,28 @@ def test_validate_literal_error_msg(params, allowed_value_str, value_str):
         err.value.args[0] == f"prop value for {str(params[0])} of the `{params[-1]}` "
         f"component should be one of the following: {allowed_value_str}. Got {value_str} instead"
     )
+
+
+@pytest.mark.parametrize(
+    "cls,cls_check,expected",
+    [
+        (int, Any, True),
+        (Tuple[int], Any, True),
+        (List[int], Any, True),
+        (int, int, True),
+        (int, object, True),
+        (int, Union[int, str], True),
+        (int, Union[str, int], True),
+        (str, Union[str, int], True),
+        (str, Union[int, str], True),
+        (int, Union[str, float, int], True),
+        (int, Union[str, float], False),
+        (int, Union[float, str], False),
+        (int, str, False),
+        (int, List[int], False),
+    ],
+)
+def test_issubclass(
+    cls: types.GenericType, cls_check: types.GenericType, expected: bool
+) -> None:
+    assert types._issubclass(cls, cls_check) == expected