Jelajahi Sumber

Validate component children (#1647)

Elijah Ahianyo 1 tahun lalu
induk
melakukan
217a5806ee

+ 36 - 14
reflex/components/component.py

@@ -66,6 +66,10 @@ class Component(Base, ABC):
 
 
     # components that cannot be children
     # components that cannot be children
     invalid_children: List[str] = []
     invalid_children: List[str] = []
+
+    # components that are only allowed as children
+    valid_children: List[str] = []
+
     # custom attribute
     # custom attribute
     custom_attrs: Dict[str, str] = {}
     custom_attrs: Dict[str, str] = {}
 
 
@@ -103,9 +107,10 @@ class Component(Base, ABC):
             TypeError: If an invalid prop is passed.
             TypeError: If an invalid prop is passed.
         """
         """
         # Set the id and children initially.
         # Set the id and children initially.
+        children = kwargs.get("children", [])
         initial_kwargs = {
         initial_kwargs = {
             "id": kwargs.get("id"),
             "id": kwargs.get("id"),
-            "children": kwargs.get("children", []),
+            "children": children,
             **{
             **{
                 prop: Var.create(kwargs[prop])
                 prop: Var.create(kwargs[prop])
                 for prop in self.get_initial_props()
                 for prop in self.get_initial_props()
@@ -114,6 +119,8 @@ class Component(Base, ABC):
         }
         }
         super().__init__(**initial_kwargs)
         super().__init__(**initial_kwargs)
 
 
+        self._validate_component_children(children)
+
         # Get the component fields, triggers, and props.
         # Get the component fields, triggers, and props.
         fields = self.get_fields()
         fields = self.get_fields()
         triggers = self.get_triggers()
         triggers = self.get_triggers()
@@ -381,6 +388,7 @@ class Component(Base, ABC):
             else Bare.create(contents=Var.create(child, is_string=True))
             else Bare.create(contents=Var.create(child, is_string=True))
             for child in children
             for child in children
         ]
         ]
+
         return cls(children=children, **props)
         return cls(children=children, **props)
 
 
     def _add_style(self, style):
     def _add_style(self, style):
@@ -435,30 +443,44 @@ class Component(Base, ABC):
             ),
             ),
             autofocus=self.autofocus,
             autofocus=self.autofocus,
         )
         )
-        self._validate_component_children(
-            rendered_dict["name"], rendered_dict["children"]
-        )
         return rendered_dict
         return rendered_dict
 
 
-    def _validate_component_children(self, comp_name: str, children: List[Dict]):
+    def _validate_component_children(self, children: List[Component]):
         """Validate the children components.
         """Validate the children components.
 
 
         Args:
         Args:
-            comp_name: name of the component.
-            children: list of children components.
+            children: The children of the component.
 
 
-        Raises:
-            ValueError: when an unsupported component is matched.
         """
         """
-        if not self.invalid_children:
+        if not self.invalid_children and not self.valid_children:
             return
             return
-        for child in children:
-            name = child["name"]
-            if name in self.invalid_children:
+
+        comp_name = type(self).__name__
+
+        def validate_invalid_child(child_name):
+            if child_name in self.invalid_children:
                 raise ValueError(
                 raise ValueError(
-                    f"The component `{comp_name.lower()}` cannot have `{name.lower()}` as a child component"
+                    f"The component `{comp_name}` cannot have `{child_name}` as a child component"
                 )
                 )
 
 
+        def validate_valid_child(child_name):
+            if child_name not in self.valid_children:
+                valid_child_list = ", ".join(
+                    [f"`{v_child}`" for v_child in self.valid_children]
+                )
+                raise ValueError(
+                    f"The component `{comp_name}` only allows the components: {valid_child_list} as children. Got `{child_name}` instead."
+                )
+
+        for child in children:
+            name = type(child).__name__
+
+            if self.invalid_children:
+                validate_invalid_child(name)
+
+            if self.valid_children:
+                validate_valid_child(name)
+
     def _get_custom_code(self) -> Optional[str]:
     def _get_custom_code(self) -> Optional[str]:
         """Get custom code for the component.
         """Get custom code for the component.
 
 

+ 4 - 0
reflex/components/forms/button.py

@@ -1,4 +1,5 @@
 """A button component."""
 """A button component."""
+from typing import List
 
 
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.vars import Var
 from reflex.vars import Var
@@ -42,6 +43,9 @@ class Button(ChakraComponent):
     # The type of button.
     # The type of button.
     type_: Var[str]
     type_: Var[str]
 
 
+    # Components that are not allowed as children.
+    invalid_children: List[str] = ["Button", "MenuButton"]
+
 
 
 class ButtonGroup(ChakraComponent):
 class ButtonGroup(ChakraComponent):
     """A group of buttons."""
     """A group of buttons."""

+ 4 - 1
reflex/components/overlay/menu.py

@@ -1,6 +1,6 @@
 """Menu components."""
 """Menu components."""
 
 
-from typing import Set
+from typing import List, Set
 
 
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
@@ -100,6 +100,9 @@ class MenuButton(ChakraComponent):
     # The variant of the menu button.
     # The variant of the menu button.
     variant: Var[str]
     variant: Var[str]
 
 
+    # Components that are not allowed as children.
+    invalid_children: List[str] = ["Button", "MenuButton"]
+
     # The tag to use for the menu button.
     # The tag to use for the menu button.
     as_: Var[str]
     as_: Var[str]
 
 

+ 64 - 6
tests/components/test_component.py

@@ -121,13 +121,47 @@ def component5() -> Type[Component]:
     """
     """
 
 
     class TestComponent5(Component):
     class TestComponent5(Component):
-        tag = "Tag"
+        tag = "RandomComponent"
 
 
         invalid_children: List[str] = ["Text"]
         invalid_children: List[str] = ["Text"]
 
 
+        valid_children: List[str] = ["Text"]
+
     return TestComponent5
     return TestComponent5
 
 
 
 
+@pytest.fixture
+def component6() -> Type[Component]:
+    """A test component.
+
+    Returns:
+        A test component.
+    """
+
+    class TestComponent6(Component):
+        tag = "RandomComponent"
+
+        invalid_children: List[str] = ["Text"]
+
+    return TestComponent6
+
+
+@pytest.fixture
+def component7() -> Type[Component]:
+    """A test component.
+
+    Returns:
+        A test component.
+    """
+
+    class TestComponent7(Component):
+        tag = "RandomComponent"
+
+        valid_children: List[str] = ["Text"]
+
+    return TestComponent7
+
+
 @pytest.fixture
 @pytest.fixture
 def on_click1() -> EventHandler:
 def on_click1() -> EventHandler:
     """A sample on click function.
     """A sample on click function.
@@ -461,16 +495,40 @@ def test_get_hooks_nested2(component3, component4):
     )
     )
 
 
 
 
-def test_unsupported_child_components(component5):
-    """Test that a value error is raised when an unsupported component is provided as a child.
+@pytest.mark.parametrize("fixture", ["component5", "component6"])
+def test_unsupported_child_components(fixture, request):
+    """Test that a value error is raised when an unsupported component (a child component found in the
+    component's invalid children list) is provided as a child.
+
+    Args:
+        fixture: the test component as a fixture.
+        request: Pytest request.
+    """
+    component = request.getfixturevalue(fixture)
+    with pytest.raises(ValueError) as err:
+        comp = component.create(rx.text("testing component"))
+        comp.render()
+    assert (
+        err.value.args[0]
+        == f"The component `{component.__name__}` cannot have `Text` as a child component"
+    )
+
+
+@pytest.mark.parametrize("fixture", ["component5", "component7"])
+def test_component_with_only_valid_children(fixture, request):
+    """Test that a value error is raised when an unsupported component (a child component not found in the
+    component's valid children list) is provided as a child.
 
 
     Args:
     Args:
-        component5: the test component
+        fixture: the test component as a fixture.
+        request: Pytest request.
     """
     """
+    component = request.getfixturevalue(fixture)
     with pytest.raises(ValueError) as err:
     with pytest.raises(ValueError) as err:
-        comp = component5.create(rx.text("testing component"))
+        comp = component.create(rx.box("testing component"))
         comp.render()
         comp.render()
     assert (
     assert (
         err.value.args[0]
         err.value.args[0]
-        == f"The component `tag` cannot have `text` as a child component"
+        == f"The component `{component.__name__}` only allows the components: `Text` as children. "
+        f"Got `Box` instead."
     )
     )