Bladeren bron

Fix custom components special props (#1956)

Nikhil Rao 1 jaar geleden
bovenliggende
commit
7019708638
2 gewijzigde bestanden met toevoegingen van 34 en 11 verwijderingen
  1. 12 7
      reflex/components/component.py
  2. 22 4
      reflex/components/typography/markdown.py

+ 12 - 7
reflex/components/component.py

@@ -540,7 +540,7 @@ class Component(Base, ABC):
             if self.valid_children:
             if self.valid_children:
                 validate_valid_child(name)
                 validate_valid_child(name)
 
 
-    def _get_custom_code(self) -> Optional[str]:
+    def _get_custom_code(self) -> str | None:
         """Get custom code for the component.
         """Get custom code for the component.
 
 
         Returns:
         Returns:
@@ -569,7 +569,7 @@ class Component(Base, ABC):
         # Return the code.
         # Return the code.
         return code
         return code
 
 
-    def _get_dynamic_imports(self) -> Optional[str]:
+    def _get_dynamic_imports(self) -> str | None:
         """Get dynamic import for the component.
         """Get dynamic import for the component.
 
 
         Returns:
         Returns:
@@ -667,7 +667,7 @@ class Component(Base, ABC):
             if hook
             if hook
         )
         )
 
 
-    def _get_hooks(self) -> Optional[str]:
+    def _get_hooks(self) -> str | None:
         """Get the React hooks for this component.
         """Get the React hooks for this component.
 
 
         Downstream components should override this method to add their own hooks.
         Downstream components should override this method to add their own hooks.
@@ -697,7 +697,7 @@ class Component(Base, ABC):
 
 
         return code
         return code
 
 
-    def get_ref(self) -> Optional[str]:
+    def get_ref(self) -> str | None:
         """Get the name of the ref for the component.
         """Get the name of the ref for the component.
 
 
         Returns:
         Returns:
@@ -723,7 +723,7 @@ class Component(Base, ABC):
         return refs
         return refs
 
 
     def get_custom_components(
     def get_custom_components(
-        self, seen: Optional[Set[str]] = None
+        self, seen: set[str] | None = None
     ) -> Set[CustomComponent]:
     ) -> Set[CustomComponent]:
         """Get all the custom components used by the component.
         """Get all the custom components used by the component.
 
 
@@ -846,7 +846,7 @@ class CustomComponent(Component):
         return set()
         return set()
 
 
     def get_custom_components(
     def get_custom_components(
-        self, seen: Optional[Set[str]] = None
+        self, seen: set[str] | None = None
     ) -> Set[CustomComponent]:
     ) -> Set[CustomComponent]:
         """Get all the custom components used by the component.
         """Get all the custom components used by the component.
 
 
@@ -875,7 +875,10 @@ class CustomComponent(Component):
         Returns:
         Returns:
             The tag to render.
             The tag to render.
         """
         """
-        return Tag(name=self.tag).add_props(**self.props)
+        return Tag(
+            name=self.tag if not self.alias else self.alias,
+            special_props=self.special_props,
+        ).add_props(**self.props)
 
 
     def get_prop_vars(self) -> List[BaseVar]:
     def get_prop_vars(self) -> List[BaseVar]:
         """Get the prop vars.
         """Get the prop vars.
@@ -914,6 +917,8 @@ def custom_component(
 
 
     @wraps(component_fn)
     @wraps(component_fn)
     def wrapper(*children, **props) -> CustomComponent:
     def wrapper(*children, **props) -> CustomComponent:
+        # Remove the children from the props.
+        props.pop("children", None)
         return CustomComponent(component_fn=component_fn, children=children, **props)
         return CustomComponent(component_fn=component_fn, children=children, **props)
 
 
     return wrapper
     return wrapper

+ 22 - 4
reflex/components/typography/markdown.py

@@ -6,7 +6,7 @@ import textwrap
 from typing import Any, Callable, Dict, Union
 from typing import Any, Callable, Dict, Union
 
 
 from reflex.compiler import utils
 from reflex.compiler import utils
-from reflex.components.component import Component
+from reflex.components.component import Component, CustomComponent
 from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList
 from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList
 from reflex.components.navigation import Link
 from reflex.components.navigation import Link
 from reflex.components.tags.tag import Tag
 from reflex.components.tags.tag import Tag
@@ -19,6 +19,7 @@ from reflex.vars import ImportVar, Var
 # Special vars used in the component map.
 # Special vars used in the component map.
 _CHILDREN = Var.create_safe("children", is_local=False)
 _CHILDREN = Var.create_safe("children", is_local=False)
 _PROPS = Var.create_safe("...props", is_local=False)
 _PROPS = Var.create_safe("...props", is_local=False)
+_MOCK_ARG = Var.create_safe("")
 
 
 # Special remark plugins.
 # Special remark plugins.
 _REMARK_MATH = Var.create_safe("remarkMath", is_local=False)
 _REMARK_MATH = Var.create_safe("remarkMath", is_local=False)
@@ -122,6 +123,25 @@ class Markdown(Component):
         # Create the component.
         # Create the component.
         return super().create(src, component_map=component_map, **props)
         return super().create(src, component_map=component_map, **props)
 
 
+    def get_custom_components(
+        self, seen: set[str] | None = None
+    ) -> set[CustomComponent]:
+        """Get all the custom components used by the component.
+
+        Args:
+            seen: The tags of the components that have already been seen.
+
+        Returns:
+            The set of custom components.
+        """
+        custom_components = super().get_custom_components(seen=seen)
+
+        # Get the custom components for each tag.
+        for component in self.component_map.values():
+            custom_components |= component(_MOCK_ARG).get_custom_components(seen=seen)
+
+        return custom_components
+
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         # Import here to avoid circular imports.
         # Import here to avoid circular imports.
         from reflex.components.datadisplay.code import Code, CodeBlock
         from reflex.components.datadisplay.code import Code, CodeBlock
@@ -145,9 +165,7 @@ class Markdown(Component):
 
 
         # Get the imports for each component.
         # Get the imports for each component.
         for component in self.component_map.values():
         for component in self.component_map.values():
-            imports = utils.merge_imports(
-                imports, component(Var.create("")).get_imports()
-            )
+            imports = utils.merge_imports(imports, component(_MOCK_ARG).get_imports())
 
 
         # Get the imports for the code components.
         # Get the imports for the code components.
         imports = utils.merge_imports(
         imports = utils.merge_imports(