Răsfoiți Sursa

Validate component children (#1647)

Elijah Ahianyo 1 an în urmă
părinte
comite
217a5806ee

+ 36 - 14
reflex/components/component.py

@@ -66,6 +66,10 @@ class Component(Base, ABC):
 
     # components that cannot be children
     invalid_children: List[str] = []
+
+    # components that are only allowed as children
+    valid_children: List[str] = []
+
     # custom attribute
     custom_attrs: Dict[str, str] = {}
 
@@ -103,9 +107,10 @@ class Component(Base, ABC):
             TypeError: If an invalid prop is passed.
         """
         # Set the id and children initially.
+        children = kwargs.get("children", [])
         initial_kwargs = {
             "id": kwargs.get("id"),
-            "children": kwargs.get("children", []),
+            "children": children,
             **{
                 prop: Var.create(kwargs[prop])
                 for prop in self.get_initial_props()
@@ -114,6 +119,8 @@ class Component(Base, ABC):
         }
         super().__init__(**initial_kwargs)
 
+        self._validate_component_children(children)
+
         # Get the component fields, triggers, and props.
         fields = self.get_fields()
         triggers = self.get_triggers()
@@ -381,6 +388,7 @@ class Component(Base, ABC):
             else Bare.create(contents=Var.create(child, is_string=True))
             for child in children
         ]
+
         return cls(children=children, **props)
 
     def _add_style(self, style):
@@ -435,30 +443,44 @@ class Component(Base, ABC):
             ),
             autofocus=self.autofocus,
         )
-        self._validate_component_children(
-            rendered_dict["name"], rendered_dict["children"]
-        )
         return rendered_dict
 
-    def _validate_component_children(self, comp_name: str, children: List[Dict]):
+    def _validate_component_children(self, children: List[Component]):
         """Validate the children components.
 
         Args:
-            comp_name: name of the component.
-            children: list of children components.
+            children: The children of the component.
 
-        Raises:
-            ValueError: when an unsupported component is matched.
         """
-        if not self.invalid_children:
+        if not self.invalid_children and not self.valid_children:
             return
-        for child in children:
-            name = child["name"]
-            if name in self.invalid_children:
+
+        comp_name = type(self).__name__
+
+        def validate_invalid_child(child_name):
+            if child_name in self.invalid_children:
                 raise ValueError(
-                    f"The component `{comp_name.lower()}` cannot have `{name.lower()}` as a child component"
+                    f"The component `{comp_name}` cannot have `{child_name}` as a child component"
                 )
 
+        def validate_valid_child(child_name):
+            if child_name not in self.valid_children:
+                valid_child_list = ", ".join(
+                    [f"`{v_child}`" for v_child in self.valid_children]
+                )
+                raise ValueError(
+                    f"The component `{comp_name}` only allows the components: {valid_child_list} as children. Got `{child_name}` instead."
+                )
+
+        for child in children:
+            name = type(child).__name__
+
+            if self.invalid_children:
+                validate_invalid_child(name)
+
+            if self.valid_children:
+                validate_valid_child(name)
+
     def _get_custom_code(self) -> Optional[str]:
         """Get custom code for the component.
 

+ 4 - 0
reflex/components/forms/button.py

@@ -1,4 +1,5 @@
 """A button component."""
+from typing import List
 
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.vars import Var
@@ -42,6 +43,9 @@ class Button(ChakraComponent):
     # The type of button.
     type_: Var[str]
 
+    # Components that are not allowed as children.
+    invalid_children: List[str] = ["Button", "MenuButton"]
+
 
 class ButtonGroup(ChakraComponent):
     """A group of buttons."""

+ 4 - 1
reflex/components/overlay/menu.py

@@ -1,6 +1,6 @@
 """Menu components."""
 
-from typing import Set
+from typing import List, Set
 
 from reflex.components.component import Component
 from reflex.components.libs.chakra import ChakraComponent
@@ -100,6 +100,9 @@ class MenuButton(ChakraComponent):
     # The variant of the menu button.
     variant: Var[str]
 
+    # Components that are not allowed as children.
+    invalid_children: List[str] = ["Button", "MenuButton"]
+
     # The tag to use for the menu button.
     as_: Var[str]
 

+ 64 - 6
tests/components/test_component.py

@@ -121,13 +121,47 @@ def component5() -> Type[Component]:
     """
 
     class TestComponent5(Component):
-        tag = "Tag"
+        tag = "RandomComponent"
 
         invalid_children: List[str] = ["Text"]
 
+        valid_children: List[str] = ["Text"]
+
     return TestComponent5
 
 
+@pytest.fixture
+def component6() -> Type[Component]:
+    """A test component.
+
+    Returns:
+        A test component.
+    """
+
+    class TestComponent6(Component):
+        tag = "RandomComponent"
+
+        invalid_children: List[str] = ["Text"]
+
+    return TestComponent6
+
+
+@pytest.fixture
+def component7() -> Type[Component]:
+    """A test component.
+
+    Returns:
+        A test component.
+    """
+
+    class TestComponent7(Component):
+        tag = "RandomComponent"
+
+        valid_children: List[str] = ["Text"]
+
+    return TestComponent7
+
+
 @pytest.fixture
 def on_click1() -> EventHandler:
     """A sample on click function.
@@ -461,16 +495,40 @@ def test_get_hooks_nested2(component3, component4):
     )
 
 
-def test_unsupported_child_components(component5):
-    """Test that a value error is raised when an unsupported component is provided as a child.
+@pytest.mark.parametrize("fixture", ["component5", "component6"])
+def test_unsupported_child_components(fixture, request):
+    """Test that a value error is raised when an unsupported component (a child component found in the
+    component's invalid children list) is provided as a child.
+
+    Args:
+        fixture: the test component as a fixture.
+        request: Pytest request.
+    """
+    component = request.getfixturevalue(fixture)
+    with pytest.raises(ValueError) as err:
+        comp = component.create(rx.text("testing component"))
+        comp.render()
+    assert (
+        err.value.args[0]
+        == f"The component `{component.__name__}` cannot have `Text` as a child component"
+    )
+
+
+@pytest.mark.parametrize("fixture", ["component5", "component7"])
+def test_component_with_only_valid_children(fixture, request):
+    """Test that a value error is raised when an unsupported component (a child component not found in the
+    component's valid children list) is provided as a child.
 
     Args:
-        component5: the test component
+        fixture: the test component as a fixture.
+        request: Pytest request.
     """
+    component = request.getfixturevalue(fixture)
     with pytest.raises(ValueError) as err:
-        comp = component5.create(rx.text("testing component"))
+        comp = component.create(rx.box("testing component"))
         comp.render()
     assert (
         err.value.args[0]
-        == f"The component `tag` cannot have `text` as a child component"
+        == f"The component `{component.__name__}` only allows the components: `Text` as children. "
+        f"Got `Box` instead."
     )