瀏覽代碼

[REF-1919] Valid Children/parents to allow Foreach,Cond,Match and Fragment (#2591)

Elijah Ahianyo 1 年之前
父節點
當前提交
5e9b472d1b
共有 2 個文件被更改,包括 267 次插入19 次删除
  1. 35 19
      reflex/components/component.py
  2. 232 0
      tests/components/test_component.py

+ 35 - 19
reflex/components/component.py

@@ -668,20 +668,43 @@ class Component(BaseComponent, ABC):
             children: The children of the component.
 
         """
-        skip_parentable = all(child._valid_parents == [] for child in children)
-        if not self._invalid_children and not self._valid_children and skip_parentable:
+        no_valid_parents_defined = all(child._valid_parents == [] for child in children)
+        if (
+            not self._invalid_children
+            and not self._valid_children
+            and no_valid_parents_defined
+        ):
             return
 
         comp_name = type(self).__name__
+        allowed_components = ["Fragment", "Foreach", "Cond", "Match"]
+
+        def validate_child(child):
+            child_name = type(child).__name__
+
+            # Iterate through the immediate children of fragment
+            if child_name == "Fragment":
+                for c in child.children:
+                    validate_child(c)
+
+            if child_name == "Cond":
+                validate_child(child.comp1)
+                validate_child(child.comp2)
 
-        def validate_invalid_child(child_name):
-            if child_name in self._invalid_children:
+            if child_name == "Match":
+                for cases in child.match_cases:
+                    validate_child(cases[-1])
+                validate_child(child.default)
+
+            if self._invalid_children and child_name in self._invalid_children:
                 raise ValueError(
                     f"The component `{comp_name}` cannot have `{child_name}` as a child component"
                 )
 
-        def validate_valid_child(child_name):
-            if child_name not in self._valid_children:
+            if self._valid_children and child_name not in [
+                *self._valid_children,
+                *allowed_components,
+            ]:
                 valid_child_list = ", ".join(
                     [f"`{v_child}`" for v_child in self._valid_children]
                 )
@@ -689,26 +712,19 @@ class Component(BaseComponent, ABC):
                     f"The component `{comp_name}` only allows the components: {valid_child_list} as children. Got `{child_name}` instead."
                 )
 
-        def validate_vaild_parent(child_name, valid_parents):
-            if comp_name not in valid_parents:
+            if child._valid_parents and comp_name not in [
+                *child._valid_parents,
+                *allowed_components,
+            ]:
                 valid_parent_list = ", ".join(
-                    [f"`{v_parent}`" for v_parent in valid_parents]
+                    [f"`{v_parent}`" for v_parent in child._valid_parents]
                 )
                 raise ValueError(
                     f"The component `{child_name}` can only be a child of the components: {valid_parent_list}. Got `{comp_name}` instead."
                 )
 
         for child in children:
-            name = type(child).__name__
-
-            if self._invalid_children:
-                validate_invalid_child(name)
-
-            if self._valid_children:
-                validate_valid_child(name)
-
-            if child._valid_parents:
-                validate_vaild_parent(name, child._valid_parents)
+            validate_child(child)
 
     @staticmethod
     def _get_vars_from_event_triggers(

+ 232 - 0
tests/components/test_component.py

@@ -948,3 +948,235 @@ def test_instantiate_all_components():
         component = getattr(rx, component_name)
         if isinstance(component, type) and issubclass(component, Component):
             component.create()
+
+
+class InvalidParentComponent(Component):
+    """Invalid Parent Component."""
+
+    ...
+
+
+class ValidComponent1(Component):
+    """Test valid component."""
+
+    _valid_children = ["ValidComponent2"]
+
+
+class ValidComponent2(Component):
+    """Test valid component."""
+
+    ...
+
+
+class ValidComponent3(Component):
+    """Test valid component."""
+
+    _valid_parents = ["ValidComponent2"]
+
+
+class ValidComponent4(Component):
+    """Test valid component."""
+
+    _invalid_children = ["InvalidComponent"]
+
+
+class InvalidComponent(Component):
+    """Test invalid component."""
+
+    ...
+
+
+valid_component1 = ValidComponent1.create
+valid_component2 = ValidComponent2.create
+invalid_component = InvalidComponent.create
+valid_component3 = ValidComponent3.create
+invalid_parent = InvalidParentComponent.create
+valid_component4 = ValidComponent4.create
+
+
+def test_validate_valid_children():
+    valid_component1(valid_component2())
+    valid_component1(
+        rx.fragment(valid_component2()),
+    )
+    valid_component1(
+        rx.fragment(
+            rx.fragment(
+                rx.fragment(valid_component2()),
+            ),
+        ),
+    )
+
+    valid_component1(
+        rx.cond(  # type: ignore
+            True,
+            rx.fragment(valid_component2()),
+            rx.fragment(
+                rx.foreach(Var.create([1, 2, 3]), lambda x: valid_component2(x))  # type: ignore
+            ),
+        )
+    )
+
+    valid_component1(
+        rx.cond(
+            True,
+            valid_component2(),
+            rx.fragment(
+                rx.match(
+                    "condition",
+                    ("first", valid_component2()),
+                    rx.fragment(valid_component2(rx.text("default"))),
+                )
+            ),
+        )
+    )
+
+    valid_component1(
+        rx.match(
+            "condition",
+            ("first", valid_component2()),
+            ("second", "third", rx.fragment(valid_component2())),
+            (
+                "fourth",
+                rx.cond(True, valid_component2(), rx.fragment(valid_component2())),
+            ),
+            (
+                "fifth",
+                rx.match(
+                    "nested_condition",
+                    ("nested_first", valid_component2()),
+                    rx.fragment(valid_component2()),
+                ),
+                valid_component2(),
+            ),
+        )
+    )
+
+
+def test_validate_valid_parents():
+    valid_component2(valid_component3())
+    valid_component2(
+        rx.fragment(valid_component3()),
+    )
+    valid_component1(
+        rx.fragment(
+            valid_component2(
+                rx.fragment(valid_component3()),
+            ),
+        ),
+    )
+
+    valid_component2(
+        rx.cond(  # type: ignore
+            True,
+            rx.fragment(valid_component3()),
+            rx.fragment(
+                rx.foreach(
+                    Var.create([1, 2, 3]),  # type: ignore
+                    lambda x: valid_component2(valid_component3(x)),
+                )
+            ),
+        )
+    )
+
+    valid_component2(
+        rx.cond(
+            True,
+            valid_component3(),
+            rx.fragment(
+                rx.match(
+                    "condition",
+                    ("first", valid_component3()),
+                    rx.fragment(valid_component3(rx.text("default"))),
+                )
+            ),
+        )
+    )
+
+    valid_component2(
+        rx.match(
+            "condition",
+            ("first", valid_component3()),
+            ("second", "third", rx.fragment(valid_component3())),
+            (
+                "fourth",
+                rx.cond(True, valid_component3(), rx.fragment(valid_component3())),
+            ),
+            (
+                "fifth",
+                rx.match(
+                    "nested_condition",
+                    ("nested_first", valid_component3()),
+                    rx.fragment(valid_component3()),
+                ),
+                valid_component3(),
+            ),
+        )
+    )
+
+
+def test_validate_invalid_children():
+    with pytest.raises(ValueError):
+        valid_component4(invalid_component())
+
+    with pytest.raises(ValueError):
+        valid_component4(
+            rx.fragment(invalid_component()),
+        )
+
+    with pytest.raises(ValueError):
+        valid_component2(
+            rx.fragment(
+                valid_component4(
+                    rx.fragment(invalid_component()),
+                ),
+            ),
+        )
+
+    with pytest.raises(ValueError):
+        valid_component4(
+            rx.cond(  # type: ignore
+                True,
+                rx.fragment(invalid_component()),
+                rx.fragment(
+                    rx.foreach(Var.create([1, 2, 3]), lambda x: invalid_component(x))  # type: ignore
+                ),
+            )
+        )
+
+    with pytest.raises(ValueError):
+        valid_component4(
+            rx.cond(
+                True,
+                invalid_component(),
+                rx.fragment(
+                    rx.match(
+                        "condition",
+                        ("first", invalid_component()),
+                        rx.fragment(invalid_component(rx.text("default"))),
+                    )
+                ),
+            )
+        )
+
+    with pytest.raises(ValueError):
+        valid_component4(
+            rx.match(
+                "condition",
+                ("first", invalid_component()),
+                ("second", "third", rx.fragment(invalid_component())),
+                (
+                    "fourth",
+                    rx.cond(True, invalid_component(), rx.fragment(valid_component2())),
+                ),
+                (
+                    "fifth",
+                    rx.match(
+                        "nested_condition",
+                        ("nested_first", invalid_component()),
+                        rx.fragment(invalid_component()),
+                    ),
+                    invalid_component(),
+                ),
+            )
+        )