Bläddra i källkod

[ENG-4010]Codeblock cleanup in markdown (#4233)

* Codeblock cleanup in markdown

* Initial approach to getting this working with rx.memo and reflex web

* abstract the map var logic

* the tests are not valid + pyright fix

* darglint fix

* Add unit tests plus mix components

* pyi run

* rebase on main

* fix darglint

* testing different OS

* revert

* This should fix it. Right?

* Fix tests

* minor fn signature fix

* use ArgsFunctionOperation

* use destructured args and pass the tests

* fix remaining unit tests

* fix pyi files

* rebase on main

* move language regex on codeblock to markdown

* fix tests

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Elijah Ahianyo 6 månader sedan
förälder
incheckning
cd59ab5406

+ 0 - 14
reflex/.templates/jinja/web/pages/custom_component.js.jinja2

@@ -8,20 +8,6 @@
 {% endfor %}
 {% endfor %}
 
 
 export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => {
 export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => {
-{% if component.name == "CodeBlock" and "language" in component.props %}
-    if (language) {
-      (async () => {
-        try {
-          const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${language}`);
-          SyntaxHighlighter.registerLanguage(language, module.default);
-        } catch (error) {
-          console.error(`Error importing language module for ${language}:`, error);
-        }
-      })();
-
-
-    }
-{% endif %}
     {% for hook in component.hooks %}
     {% for hook in component.hooks %}
     {{ hook }}
     {{ hook }}
     {% endfor %}
     {% endfor %}

+ 45 - 37
reflex/components/datadisplay/code.py

@@ -8,13 +8,14 @@ from typing import ClassVar, Dict, Literal, Optional, Union
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.core.cond import color_mode_cond
 from reflex.components.core.cond import color_mode_cond
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
+from reflex.components.markdown.markdown import _LANGUAGE, MarkdownComponentMap
 from reflex.components.radix.themes.components.button import Button
 from reflex.components.radix.themes.components.button import Button
 from reflex.components.radix.themes.layout.box import Box
 from reflex.components.radix.themes.layout.box import Box
 from reflex.constants.colors import Color
 from reflex.constants.colors import Color
 from reflex.event import set_clipboard
 from reflex.event import set_clipboard
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, format
 from reflex.utils import console, format
-from reflex.utils.imports import ImportDict, ImportVar
+from reflex.utils.imports import ImportVar
 from reflex.vars.base import LiteralVar, Var, VarData
 from reflex.vars.base import LiteralVar, Var, VarData
 
 
 LiteralCodeLanguage = Literal[
 LiteralCodeLanguage = Literal[
@@ -378,7 +379,7 @@ for theme_name in dir(Theme):
     setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme))
     setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme))
 
 
 
 
-class CodeBlock(Component):
+class CodeBlock(Component, MarkdownComponentMap):
     """A code block."""
     """A code block."""
 
 
     library = "react-syntax-highlighter@15.6.1"
     library = "react-syntax-highlighter@15.6.1"
@@ -417,39 +418,6 @@ class CodeBlock(Component):
     # A custom copy button to override the default one.
     # A custom copy button to override the default one.
     copy_button: Optional[Union[bool, Component]] = None
     copy_button: Optional[Union[bool, Component]] = None
 
 
-    def add_imports(self) -> ImportDict:
-        """Add imports for the CodeBlock component.
-
-        Returns:
-            The import dict.
-        """
-        imports_: ImportDict = {}
-
-        if (
-            self.language is not None
-            and (language_without_quotes := str(self.language).replace('"', ""))
-            in LiteralCodeLanguage.__args__  # type: ignore
-        ):
-            imports_[
-                f"react-syntax-highlighter/dist/cjs/languages/prism/{language_without_quotes}"
-            ] = [
-                ImportVar(
-                    tag=format.to_camel_case(language_without_quotes),
-                    is_default=True,
-                    install=False,
-                )
-            ]
-
-        return imports_
-
-    def _get_custom_code(self) -> Optional[str]:
-        if (
-            self.language is not None
-            and (language_without_quotes := str(self.language).replace('"', ""))
-            in LiteralCodeLanguage.__args__  # type: ignore
-        ):
-            return f"{self.alias}.registerLanguage('{language_without_quotes}', {format.to_camel_case(language_without_quotes)})"
-
     @classmethod
     @classmethod
     def create(
     def create(
         cls,
         cls,
@@ -534,8 +502,8 @@ class CodeBlock(Component):
 
 
         theme = self.theme
         theme = self.theme
 
 
-        out.add_props(style=theme).remove_props("theme", "code").add_props(
-            children=self.code
+        out.add_props(style=theme).remove_props("theme", "code", "language").add_props(
+            children=self.code, language=_LANGUAGE
         )
         )
 
 
         return out
         return out
@@ -543,6 +511,46 @@ class CodeBlock(Component):
     def _exclude_props(self) -> list[str]:
     def _exclude_props(self) -> list[str]:
         return ["can_copy", "copy_button"]
         return ["can_copy", "copy_button"]
 
 
+    @classmethod
+    def _get_language_registration_hook(cls) -> str:
+        """Get the hook to register the language.
+
+        Returns:
+            The hook to register the language.
+        """
+        return f"""
+ if ({str(_LANGUAGE)}) {{
+    (async () => {{
+      try {{
+        const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{{str(_LANGUAGE)}}}`);
+        SyntaxHighlighter.registerLanguage({str(_LANGUAGE)}, module.default);
+      }} catch (error) {{
+        console.error(`Error importing language module for ${{{str(_LANGUAGE)}}}:`, error);
+      }}
+    }})();
+  }}
+"""
+
+    @classmethod
+    def get_component_map_custom_code(cls) -> str:
+        """Get the custom code for the component.
+
+        Returns:
+            The custom code for the component.
+        """
+        return cls._get_language_registration_hook()
+
+    def add_hooks(self) -> list[str | Var]:
+        """Add hooks for the component.
+
+        Returns:
+            The hooks for the component.
+        """
+        return [
+            f"const {str(_LANGUAGE)} = {str(self.language)}",
+            self._get_language_registration_hook(),
+        ]
+
 
 
 class CodeblockNamespace(ComponentNamespace):
 class CodeblockNamespace(ComponentNamespace):
     """Namespace for the CodeBlock component."""
     """Namespace for the CodeBlock component."""

+ 5 - 3
reflex/components/datadisplay/code.pyi

@@ -7,10 +7,10 @@ import dataclasses
 from typing import Any, ClassVar, Dict, Literal, Optional, Union, overload
 from typing import Any, ClassVar, Dict, Literal, Optional, Union, overload
 
 
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.constants.colors import Color
 from reflex.constants.colors import Color
 from reflex.event import BASE_STATE, EventType
 from reflex.event import BASE_STATE, EventType
 from reflex.style import Style
 from reflex.style import Style
-from reflex.utils.imports import ImportDict
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
 LiteralCodeLanguage = Literal[
 LiteralCodeLanguage = Literal[
@@ -349,8 +349,7 @@ for theme_name in dir(Theme):
         continue
         continue
     setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme))
     setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme))
 
 
-class CodeBlock(Component):
-    def add_imports(self) -> ImportDict: ...
+class CodeBlock(Component, MarkdownComponentMap):
     @overload
     @overload
     @classmethod
     @classmethod
     def create(  # type: ignore
     def create(  # type: ignore
@@ -984,6 +983,9 @@ class CodeBlock(Component):
         ...
         ...
 
 
     def add_style(self): ...
     def add_style(self): ...
+    @classmethod
+    def get_component_map_custom_code(cls) -> str: ...
+    def add_hooks(self) -> list[str | Var]: ...
 
 
 class CodeblockNamespace(ComponentNamespace):
 class CodeblockNamespace(ComponentNamespace):
     themes = Theme
     themes = Theme

+ 2 - 1
reflex/components/datadisplay/shiki_code_block.py

@@ -12,6 +12,7 @@ from reflex.components.core.colors import color
 from reflex.components.core.cond import color_mode_cond
 from reflex.components.core.cond import color_mode_cond
 from reflex.components.el.elements.forms import Button
 from reflex.components.el.elements.forms import Button
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.components.props import NoExtrasAllowedProps
 from reflex.components.props import NoExtrasAllowedProps
 from reflex.components.radix.themes.layout.box import Box
 from reflex.components.radix.themes.layout.box import Box
 from reflex.event import run_script, set_clipboard
 from reflex.event import run_script, set_clipboard
@@ -528,7 +529,7 @@ class ShikiJsTransformer(ShikiBaseTransformers):
         super().__init__(**kwargs)
         super().__init__(**kwargs)
 
 
 
 
-class ShikiCodeBlock(Component):
+class ShikiCodeBlock(Component, MarkdownComponentMap):
     """A Code block."""
     """A Code block."""
 
 
     library = "/components/shiki/code"
     library = "/components/shiki/code"

+ 2 - 1
reflex/components/datadisplay/shiki_code_block.pyi

@@ -7,6 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.components.props import NoExtrasAllowedProps
 from reflex.components.props import NoExtrasAllowedProps
 from reflex.event import BASE_STATE, EventType
 from reflex.event import BASE_STATE, EventType
 from reflex.style import Style
 from reflex.style import Style
@@ -350,7 +351,7 @@ class ShikiJsTransformer(ShikiBaseTransformers):
     fns: list[FunctionStringVar]
     fns: list[FunctionStringVar]
     style: Optional[Style]
     style: Optional[Style]
 
 
-class ShikiCodeBlock(Component):
+class ShikiCodeBlock(Component, MarkdownComponentMap):
     @overload
     @overload
     @classmethod
     @classmethod
     def create(  # type: ignore
     def create(  # type: ignore

+ 180 - 49
reflex/components/markdown/markdown.py

@@ -2,25 +2,18 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+import dataclasses
 import textwrap
 import textwrap
 from functools import lru_cache
 from functools import lru_cache
 from hashlib import md5
 from hashlib import md5
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict, Sequence, Union
 
 
 from reflex.components.component import Component, CustomComponent
 from reflex.components.component import Component, CustomComponent
-from reflex.components.radix.themes.layout.list import (
-    ListItem,
-    OrderedList,
-    UnorderedList,
-)
-from reflex.components.radix.themes.typography.heading import Heading
-from reflex.components.radix.themes.typography.link import Link
-from reflex.components.radix.themes.typography.text import Text
 from reflex.components.tags.tag import Tag
 from reflex.components.tags.tag import Tag
 from reflex.utils import types
 from reflex.utils import types
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.base import LiteralVar, Var
-from reflex.vars.function import ARRAY_ISARRAY
+from reflex.vars.function import ARRAY_ISARRAY, ArgsFunctionOperation, DestructuredArg
 from reflex.vars.number import ternary_operation
 from reflex.vars.number import ternary_operation
 
 
 # Special vars used in the component map.
 # Special vars used in the component map.
@@ -28,6 +21,7 @@ _CHILDREN = Var(_js_expr="children", _var_type=str)
 _PROPS = Var(_js_expr="...props")
 _PROPS = Var(_js_expr="...props")
 _PROPS_IN_TAG = Var(_js_expr="{...props}")
 _PROPS_IN_TAG = Var(_js_expr="{...props}")
 _MOCK_ARG = Var(_js_expr="", _var_type=str)
 _MOCK_ARG = Var(_js_expr="", _var_type=str)
+_LANGUAGE = Var(_js_expr="_language", _var_type=str)
 
 
 # Special remark plugins.
 # Special remark plugins.
 _REMARK_MATH = Var(_js_expr="remarkMath")
 _REMARK_MATH = Var(_js_expr="remarkMath")
@@ -53,7 +47,15 @@ def get_base_component_map() -> dict[str, Callable]:
         The base component map.
         The base component map.
     """
     """
     from reflex.components.datadisplay.code import CodeBlock
     from reflex.components.datadisplay.code import CodeBlock
+    from reflex.components.radix.themes.layout.list import (
+        ListItem,
+        OrderedList,
+        UnorderedList,
+    )
     from reflex.components.radix.themes.typography.code import Code
     from reflex.components.radix.themes.typography.code import Code
+    from reflex.components.radix.themes.typography.heading import Heading
+    from reflex.components.radix.themes.typography.link import Link
+    from reflex.components.radix.themes.typography.text import Text
 
 
     return {
     return {
         "h1": lambda value: Heading.create(value, as_="h1", size="6", margin_y="0.5em"),
         "h1": lambda value: Heading.create(value, as_="h1", size="6", margin_y="0.5em"),
@@ -74,6 +76,67 @@ def get_base_component_map() -> dict[str, Callable]:
     }
     }
 
 
 
 
+@dataclasses.dataclass()
+class MarkdownComponentMap:
+    """Mixin class for handling custom component maps in Markdown components."""
+
+    _explicit_return: bool = dataclasses.field(default=False)
+
+    @classmethod
+    def get_component_map_custom_code(cls) -> str:
+        """Get the custom code for the component map.
+
+        Returns:
+            The custom code for the component map.
+        """
+        return ""
+
+    @classmethod
+    def create_map_fn_var(
+        cls,
+        fn_body: Var | None = None,
+        fn_args: Sequence[str] | None = None,
+        explicit_return: bool | None = None,
+    ) -> Var:
+        """Create a function Var for the component map.
+
+        Args:
+            fn_body: The formatted component as a string.
+            fn_args: The function arguments.
+            explicit_return: Whether to use explicit return syntax.
+
+        Returns:
+            The function Var for the component map.
+        """
+        fn_args = fn_args or cls.get_fn_args()
+        fn_body = fn_body if fn_body is not None else cls.get_fn_body()
+        explicit_return = explicit_return or cls._explicit_return
+
+        return ArgsFunctionOperation.create(
+            args_names=(DestructuredArg(fields=tuple(fn_args)),),
+            return_expr=fn_body,
+            explicit_return=explicit_return,
+        )
+
+    @classmethod
+    def get_fn_args(cls) -> Sequence[str]:
+        """Get the function arguments for the component map.
+
+        Returns:
+            The function arguments as a list of strings.
+        """
+        return ["node", _CHILDREN._js_expr, _PROPS._js_expr]
+
+    @classmethod
+    def get_fn_body(cls) -> Var:
+        """Get the function body for the component map.
+
+        Returns:
+            The function body as a string.
+        """
+        return Var(_js_expr="undefined", _var_type=None)
+
+
 class Markdown(Component):
 class Markdown(Component):
     """A markdown component."""
     """A markdown component."""
 
 
@@ -153,9 +216,6 @@ class Markdown(Component):
         Returns:
         Returns:
             The imports for the markdown component.
             The imports for the markdown component.
         """
         """
-        from reflex.components.datadisplay.code import CodeBlock, Theme
-        from reflex.components.radix.themes.typography.code import Code
-
         return [
         return [
             {
             {
                 "": "katex/dist/katex.min.css",
                 "": "katex/dist/katex.min.css",
@@ -179,10 +239,71 @@ class Markdown(Component):
                 component(_MOCK_ARG)._get_all_imports()  # type: ignore
                 component(_MOCK_ARG)._get_all_imports()  # type: ignore
                 for component in self.component_map.values()
                 for component in self.component_map.values()
             ],
             ],
-            CodeBlock.create(theme=Theme.light)._get_imports(),
-            Code.create()._get_imports(),
         ]
         ]
 
 
+    def _get_tag_map_fn_var(self, tag: str) -> Var:
+        return self._get_map_fn_var_from_children(self.get_component(tag), tag)
+
+    def format_component_map(self) -> dict[str, Var]:
+        """Format the component map for rendering.
+
+        Returns:
+            The formatted component map.
+        """
+        components = {
+            tag: self._get_tag_map_fn_var(tag)
+            for tag in self.component_map
+            if tag not in ("code", "codeblock")
+        }
+
+        # Separate out inline code and code blocks.
+        components["code"] = self._get_inline_code_fn_var()
+
+        return components
+
+    def _get_inline_code_fn_var(self) -> Var:
+        """Get the function variable for inline code.
+
+        This function creates a Var that represents a function to handle
+        both inline code and code blocks in markdown.
+
+        Returns:
+            The Var for inline code.
+        """
+        # Get any custom code from the codeblock and code components.
+        custom_code_list = self._get_map_fn_custom_code_from_children(
+            self.get_component("codeblock")
+        )
+        custom_code_list.extend(
+            self._get_map_fn_custom_code_from_children(self.get_component("code"))
+        )
+
+        codeblock_custom_code = "\n".join(custom_code_list)
+
+        # Format the code to handle inline and block code.
+        formatted_code = f"""
+const match = (className || '').match(/language-(?<lang>.*)/);
+const {str(_LANGUAGE)} = match ? match[1] : '';
+{codeblock_custom_code};
+            return inline ? (
+                {self.format_component("code")}
+            ) : (
+                {self.format_component("codeblock", language=_LANGUAGE)}
+            );
+        """.replace("\n", " ")
+
+        return MarkdownComponentMap.create_map_fn_var(
+            fn_args=(
+                "node",
+                "inline",
+                "className",
+                _CHILDREN._js_expr,
+                _PROPS._js_expr,
+            ),
+            fn_body=Var(_js_expr=formatted_code),
+            explicit_return=True,
+        )
+
     def get_component(self, tag: str, **props) -> Component:
     def get_component(self, tag: str, **props) -> Component:
         """Get the component for a tag and props.
         """Get the component for a tag and props.
 
 
@@ -239,43 +360,53 @@ class Markdown(Component):
         """
         """
         return str(self.get_component(tag, **props)).replace("\n", "")
         return str(self.get_component(tag, **props)).replace("\n", "")
 
 
-    def format_component_map(self) -> dict[str, Var]:
-        """Format the component map for rendering.
+    def _get_map_fn_var_from_children(self, component: Component, tag: str) -> Var:
+        """Create a function Var for the component map for the specified tag.
+
+        Args:
+            component: The component to check for custom code.
+            tag: The tag of the component.
 
 
         Returns:
         Returns:
-            The formatted component map.
+            The function Var for the component map.
         """
         """
-        components = {
-            tag: Var(
-                _js_expr=f"(({{node, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => ({self.format_component(tag)}))"
-            )
-            for tag in self.component_map
-        }
-
-        # Separate out inline code and code blocks.
-        components["code"] = Var(
-            _js_expr=f"""(({{node, inline, className, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => {{
-    const match = (className || '').match(/language-(?<lang>.*)/);
-    const language = match ? match[1] : '';
-    if (language) {{
-    (async () => {{
-      try {{
-        const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{language}}`);
-        SyntaxHighlighter.registerLanguage(language, module.default);
-      }} catch (error) {{
-        console.error(`Error importing language module for ${{language}}:`, error);
-      }}
-    }})();
-  }}
-    return inline ? (
-        {self.format_component("code")}
-    ) : (
-        {self.format_component("codeblock", language=Var(_js_expr="language", _var_type=str))}
-    );
-      }})""".replace("\n", " ")
+        formatted_component = Var(
+            _js_expr=f"({self.format_component(tag)})", _var_type=str
         )
         )
+        if isinstance(component, MarkdownComponentMap):
+            return component.create_map_fn_var(fn_body=formatted_component)
 
 
-        return components
+        # fallback to the default fn Var creation if the component is not a MarkdownComponentMap.
+        return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component)
+
+    def _get_map_fn_custom_code_from_children(self, component) -> list[str]:
+        """Recursively get markdown custom code from children components.
+
+        Args:
+            component: The component to check for custom code.
+
+        Returns:
+            A list of markdown custom code strings.
+        """
+        custom_code_list = []
+        if isinstance(component, MarkdownComponentMap):
+            custom_code_list.append(component.get_component_map_custom_code())
+
+        # If the component is a custom component(rx.memo), obtain the underlining
+        # component and get the custom code from the children.
+        if isinstance(component, CustomComponent):
+            custom_code_list.extend(
+                self._get_map_fn_custom_code_from_children(
+                    component.component_fn(*component.get_prop_vars())
+                )
+            )
+        elif isinstance(component, Component):
+            for child in component.children:
+                custom_code_list.extend(
+                    self._get_map_fn_custom_code_from_children(child)
+                )
+
+        return custom_code_list
 
 
     @staticmethod
     @staticmethod
     def _component_map_hash(component_map) -> str:
     def _component_map_hash(component_map) -> str:
@@ -288,12 +419,12 @@ class Markdown(Component):
         return f"ComponentMap_{self.component_map_hash}"
         return f"ComponentMap_{self.component_map_hash}"
 
 
     def _get_custom_code(self) -> str | None:
     def _get_custom_code(self) -> str | None:
-        hooks = set()
+        hooks = {}
         for _component in self.component_map.values():
         for _component in self.component_map.values():
             comp = _component(_MOCK_ARG)
             comp = _component(_MOCK_ARG)
             hooks.update(comp._get_all_hooks_internal())
             hooks.update(comp._get_all_hooks_internal())
             hooks.update(comp._get_all_hooks())
             hooks.update(comp._get_all_hooks())
-        formatted_hooks = "\n".join(hooks)
+        formatted_hooks = "\n".join(hooks.keys())
         return f"""
         return f"""
         function {self._get_component_map_name()} () {{
         function {self._get_component_map_name()} () {{
             {formatted_hooks}
             {formatted_hooks}

+ 19 - 2
reflex/components/markdown/markdown.pyi

@@ -3,8 +3,9 @@
 # ------------------- DO NOT EDIT ----------------------
 # ------------------- DO NOT EDIT ----------------------
 # This file was generated by `reflex/utils/pyi_generator.py`!
 # This file was generated by `reflex/utils/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
+import dataclasses
 from functools import lru_cache
 from functools import lru_cache
-from typing import Any, Callable, Dict, Optional, Union, overload
+from typing import Any, Callable, Dict, Optional, Sequence, Union, overload
 
 
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.event import BASE_STATE, EventType
 from reflex.event import BASE_STATE, EventType
@@ -16,6 +17,7 @@ _CHILDREN = Var(_js_expr="children", _var_type=str)
 _PROPS = Var(_js_expr="...props")
 _PROPS = Var(_js_expr="...props")
 _PROPS_IN_TAG = Var(_js_expr="{...props}")
 _PROPS_IN_TAG = Var(_js_expr="{...props}")
 _MOCK_ARG = Var(_js_expr="", _var_type=str)
 _MOCK_ARG = Var(_js_expr="", _var_type=str)
+_LANGUAGE = Var(_js_expr="_language", _var_type=str)
 _REMARK_MATH = Var(_js_expr="remarkMath")
 _REMARK_MATH = Var(_js_expr="remarkMath")
 _REMARK_GFM = Var(_js_expr="remarkGfm")
 _REMARK_GFM = Var(_js_expr="remarkGfm")
 _REMARK_UNWRAP_IMAGES = Var(_js_expr="remarkUnwrapImages")
 _REMARK_UNWRAP_IMAGES = Var(_js_expr="remarkUnwrapImages")
@@ -27,6 +29,21 @@ NO_PROPS_TAGS = ("ul", "ol", "li")
 
 
 @lru_cache
 @lru_cache
 def get_base_component_map() -> dict[str, Callable]: ...
 def get_base_component_map() -> dict[str, Callable]: ...
+@dataclasses.dataclass()
+class MarkdownComponentMap:
+    @classmethod
+    def get_component_map_custom_code(cls) -> str: ...
+    @classmethod
+    def create_map_fn_var(
+        cls,
+        fn_body: Var | None = None,
+        fn_args: Sequence[str] | None = None,
+        explicit_return: bool | None = None,
+    ) -> Var: ...
+    @classmethod
+    def get_fn_args(cls) -> Sequence[str]: ...
+    @classmethod
+    def get_fn_body(cls) -> Var: ...
 
 
 class Markdown(Component):
 class Markdown(Component):
     @overload
     @overload
@@ -82,6 +99,6 @@ class Markdown(Component):
         ...
         ...
 
 
     def add_imports(self) -> ImportDict | list[ImportDict]: ...
     def add_imports(self) -> ImportDict | list[ImportDict]: ...
+    def format_component_map(self) -> dict[str, Var]: ...
     def get_component(self, tag: str, **props) -> Component: ...
     def get_component(self, tag: str, **props) -> Component: ...
     def format_component(self, tag: str, **props) -> str: ...
     def format_component(self, tag: str, **props) -> str: ...
-    def format_component_map(self) -> dict[str, Var]: ...

+ 3 - 2
reflex/components/radix/themes/layout/list.py

@@ -8,6 +8,7 @@ from reflex.components.component import Component, ComponentNamespace
 from reflex.components.core.foreach import Foreach
 from reflex.components.core.foreach import Foreach
 from reflex.components.el.elements.typography import Li, Ol, Ul
 from reflex.components.el.elements.typography import Li, Ol, Ul
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.components.radix.themes.typography.text import Text
 from reflex.components.radix.themes.typography.text import Text
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
@@ -36,7 +37,7 @@ LiteralListStyleTypeOrdered = Literal[
 ]
 ]
 
 
 
 
-class BaseList(Component):
+class BaseList(Component, MarkdownComponentMap):
     """Base class for ordered and unordered lists."""
     """Base class for ordered and unordered lists."""
 
 
     tag = "ul"
     tag = "ul"
@@ -154,7 +155,7 @@ class OrderedList(BaseList, Ol):
         )
         )
 
 
 
 
-class ListItem(Li):
+class ListItem(Li, MarkdownComponentMap):
     """Display an item of an ordered or unordered list."""
     """Display an item of an ordered or unordered list."""
 
 
     @classmethod
     @classmethod

+ 3 - 2
reflex/components/radix/themes/layout/list.pyi

@@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, Literal, Optional, Union, overload
 
 
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.el.elements.typography import Li, Ol, Ul
 from reflex.components.el.elements.typography import Li, Ol, Ul
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.event import BASE_STATE, EventType
 from reflex.event import BASE_STATE, EventType
 from reflex.style import Style
 from reflex.style import Style
 from reflex.vars.base import Var
 from reflex.vars.base import Var
@@ -29,7 +30,7 @@ LiteralListStyleTypeOrdered = Literal[
     "katakana",
     "katakana",
 ]
 ]
 
 
-class BaseList(Component):
+class BaseList(Component, MarkdownComponentMap):
     @overload
     @overload
     @classmethod
     @classmethod
     def create(  # type: ignore
     def create(  # type: ignore
@@ -393,7 +394,7 @@ class OrderedList(BaseList, Ol):
         """
         """
         ...
         ...
 
 
-class ListItem(Li):
+class ListItem(Li, MarkdownComponentMap):
     @overload
     @overload
     @classmethod
     @classmethod
     def create(  # type: ignore
     def create(  # type: ignore

+ 2 - 1
reflex/components/radix/themes/typography/code.py

@@ -7,13 +7,14 @@ from __future__ import annotations
 
 
 from reflex.components.core.breakpoints import Responsive
 from reflex.components.core.breakpoints import Responsive
 from reflex.components.el import elements
 from reflex.components.el import elements
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
 from ..base import LiteralAccentColor, LiteralVariant, RadixThemesComponent
 from ..base import LiteralAccentColor, LiteralVariant, RadixThemesComponent
 from .base import LiteralTextSize, LiteralTextWeight
 from .base import LiteralTextSize, LiteralTextWeight
 
 
 
 
-class Code(elements.Code, RadixThemesComponent):
+class Code(elements.Code, RadixThemesComponent, MarkdownComponentMap):
     """A block level extended quotation."""
     """A block level extended quotation."""
 
 
     tag = "Code"
     tag = "Code"

+ 2 - 1
reflex/components/radix/themes/typography/code.pyi

@@ -7,13 +7,14 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 
 
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.el import elements
 from reflex.components.el import elements
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.event import BASE_STATE, EventType
 from reflex.event import BASE_STATE, EventType
 from reflex.style import Style
 from reflex.style import Style
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
 from ..base import RadixThemesComponent
 from ..base import RadixThemesComponent
 
 
-class Code(elements.Code, RadixThemesComponent):
+class Code(elements.Code, RadixThemesComponent, MarkdownComponentMap):
     @overload
     @overload
     @classmethod
     @classmethod
     def create(  # type: ignore
     def create(  # type: ignore

+ 2 - 1
reflex/components/radix/themes/typography/heading.py

@@ -7,13 +7,14 @@ from __future__ import annotations
 
 
 from reflex.components.core.breakpoints import Responsive
 from reflex.components.core.breakpoints import Responsive
 from reflex.components.el import elements
 from reflex.components.el import elements
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
 from ..base import LiteralAccentColor, RadixThemesComponent
 from ..base import LiteralAccentColor, RadixThemesComponent
 from .base import LiteralTextAlign, LiteralTextSize, LiteralTextTrim, LiteralTextWeight
 from .base import LiteralTextAlign, LiteralTextSize, LiteralTextTrim, LiteralTextWeight
 
 
 
 
-class Heading(elements.H1, RadixThemesComponent):
+class Heading(elements.H1, RadixThemesComponent, MarkdownComponentMap):
     """A foundational text primitive based on the <span> element."""
     """A foundational text primitive based on the <span> element."""
 
 
     tag = "Heading"
     tag = "Heading"

+ 2 - 1
reflex/components/radix/themes/typography/heading.pyi

@@ -7,13 +7,14 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 
 
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.el import elements
 from reflex.components.el import elements
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.event import BASE_STATE, EventType
 from reflex.event import BASE_STATE, EventType
 from reflex.style import Style
 from reflex.style import Style
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
 from ..base import RadixThemesComponent
 from ..base import RadixThemesComponent
 
 
-class Heading(elements.H1, RadixThemesComponent):
+class Heading(elements.H1, RadixThemesComponent, MarkdownComponentMap):
     @overload
     @overload
     @classmethod
     @classmethod
     def create(  # type: ignore
     def create(  # type: ignore

+ 2 - 1
reflex/components/radix/themes/typography/link.py

@@ -12,6 +12,7 @@ from reflex.components.core.breakpoints import Responsive
 from reflex.components.core.colors import color
 from reflex.components.core.colors import color
 from reflex.components.core.cond import cond
 from reflex.components.core.cond import cond
 from reflex.components.el.elements.inline import A
 from reflex.components.el.elements.inline import A
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.components.next.link import NextLink
 from reflex.components.next.link import NextLink
 from reflex.utils.imports import ImportDict
 from reflex.utils.imports import ImportDict
 from reflex.vars.base import Var
 from reflex.vars.base import Var
@@ -24,7 +25,7 @@ LiteralLinkUnderline = Literal["auto", "hover", "always", "none"]
 next_link = NextLink.create()
 next_link = NextLink.create()
 
 
 
 
-class Link(RadixThemesComponent, A, MemoizationLeaf):
+class Link(RadixThemesComponent, A, MemoizationLeaf, MarkdownComponentMap):
     """A semantic element for navigation between pages."""
     """A semantic element for navigation between pages."""
 
 
     tag = "Link"
     tag = "Link"

+ 2 - 1
reflex/components/radix/themes/typography/link.pyi

@@ -8,6 +8,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.components.component import MemoizationLeaf
 from reflex.components.component import MemoizationLeaf
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.el.elements.inline import A
 from reflex.components.el.elements.inline import A
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.components.next.link import NextLink
 from reflex.components.next.link import NextLink
 from reflex.event import BASE_STATE, EventType
 from reflex.event import BASE_STATE, EventType
 from reflex.style import Style
 from reflex.style import Style
@@ -19,7 +20,7 @@ from ..base import RadixThemesComponent
 LiteralLinkUnderline = Literal["auto", "hover", "always", "none"]
 LiteralLinkUnderline = Literal["auto", "hover", "always", "none"]
 next_link = NextLink.create()
 next_link = NextLink.create()
 
 
-class Link(RadixThemesComponent, A, MemoizationLeaf):
+class Link(RadixThemesComponent, A, MemoizationLeaf, MarkdownComponentMap):
     def add_imports(self) -> ImportDict: ...
     def add_imports(self) -> ImportDict: ...
     @overload
     @overload
     @classmethod
     @classmethod

+ 2 - 1
reflex/components/radix/themes/typography/text.py

@@ -10,6 +10,7 @@ from typing import Literal
 from reflex.components.component import ComponentNamespace
 from reflex.components.component import ComponentNamespace
 from reflex.components.core.breakpoints import Responsive
 from reflex.components.core.breakpoints import Responsive
 from reflex.components.el import elements
 from reflex.components.el import elements
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
 from ..base import LiteralAccentColor, RadixThemesComponent
 from ..base import LiteralAccentColor, RadixThemesComponent
@@ -37,7 +38,7 @@ LiteralType = Literal[
 ]
 ]
 
 
 
 
-class Text(elements.Span, RadixThemesComponent):
+class Text(elements.Span, RadixThemesComponent, MarkdownComponentMap):
     """A foundational text primitive based on the <span> element."""
     """A foundational text primitive based on the <span> element."""
 
 
     tag = "Text"
     tag = "Text"

+ 2 - 1
reflex/components/radix/themes/typography/text.pyi

@@ -8,6 +8,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.components.component import ComponentNamespace
 from reflex.components.component import ComponentNamespace
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.el import elements
 from reflex.components.el import elements
+from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.event import BASE_STATE, EventType
 from reflex.event import BASE_STATE, EventType
 from reflex.style import Style
 from reflex.style import Style
 from reflex.vars.base import Var
 from reflex.vars.base import Var
@@ -35,7 +36,7 @@ LiteralType = Literal[
     "sup",
     "sup",
 ]
 ]
 
 
-class Text(elements.Span, RadixThemesComponent):
+class Text(elements.Span, RadixThemesComponent, MarkdownComponentMap):
     @overload
     @overload
     @classmethod
     @classmethod
     def create(  # type: ignore
     def create(  # type: ignore

+ 2 - 1
reflex/event.py

@@ -45,6 +45,7 @@ from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.function import (
 from reflex.vars.function import (
     ArgsFunctionOperation,
     ArgsFunctionOperation,
+    FunctionArgs,
     FunctionStringVar,
     FunctionStringVar,
     FunctionVar,
     FunctionVar,
     VarOperationCall,
     VarOperationCall,
@@ -1643,7 +1644,7 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
             _js_expr="",
             _js_expr="",
             _var_type=EventChain,
             _var_type=EventChain,
             _var_data=_var_data,
             _var_data=_var_data,
-            _args_names=arg_def,
+            _args=FunctionArgs(arg_def),
             _return_expr=invocation.call(
             _return_expr=invocation.call(
                 LiteralVar.create([LiteralVar.create(event) for event in value.events]),
                 LiteralVar.create([LiteralVar.create(event) for event in value.events]),
                 arg_def_expr,
                 arg_def_expr,

+ 58 - 5
reflex/vars/function.py

@@ -4,8 +4,9 @@ from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
 import sys
 import sys
-from typing import Any, Callable, Optional, Tuple, Type, Union
+from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
 
 
+from reflex.utils import format
 from reflex.utils.types import GenericType
 from reflex.utils.types import GenericType
 
 
 from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
 from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
@@ -126,6 +127,36 @@ class VarOperationCall(CachedVarOperation, Var):
         )
         )
 
 
 
 
+@dataclasses.dataclass(frozen=True)
+class DestructuredArg:
+    """Class for destructured arguments."""
+
+    fields: Tuple[str, ...] = tuple()
+    rest: Optional[str] = None
+
+    def to_javascript(self) -> str:
+        """Convert the destructured argument to JavaScript.
+
+        Returns:
+            The destructured argument in JavaScript.
+        """
+        return format.wrap(
+            ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
+            "{",
+            "}",
+        )
+
+
+@dataclasses.dataclass(
+    frozen=True,
+)
+class FunctionArgs:
+    """Class for function arguments."""
+
+    args: Tuple[Union[str, DestructuredArg], ...] = tuple()
+    rest: Optional[str] = None
+
+
 @dataclasses.dataclass(
 @dataclasses.dataclass(
     eq=False,
     eq=False,
     frozen=True,
     frozen=True,
@@ -134,8 +165,9 @@ class VarOperationCall(CachedVarOperation, Var):
 class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
 class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
     """Base class for immutable function defined via arguments and return expression."""
     """Base class for immutable function defined via arguments and return expression."""
 
 
-    _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
+    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
+    _explicit_return: bool = dataclasses.field(default=False)
 
 
     @cached_property_no_lock
     @cached_property_no_lock
     def _cached_var_name(self) -> str:
     def _cached_var_name(self) -> str:
@@ -144,13 +176,31 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         Returns:
         Returns:
             The name of the var.
             The name of the var.
         """
         """
-        return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
+        arg_names_str = ", ".join(
+            [
+                arg if isinstance(arg, str) else arg.to_javascript()
+                for arg in self._args.args
+            ]
+        ) + (f", ...{self._args.rest}" if self._args.rest else "")
+
+        return_expr_str = str(LiteralVar.create(self._return_expr))
+
+        # Wrap return expression in curly braces if explicit return syntax is used.
+        return_expr_str_wrapped = (
+            format.wrap(return_expr_str, "{", "}")
+            if self._explicit_return
+            else return_expr_str
+        )
+
+        return f"(({arg_names_str}) => {return_expr_str_wrapped})"
 
 
     @classmethod
     @classmethod
     def create(
     def create(
         cls,
         cls,
-        args_names: Tuple[str, ...],
+        args_names: Sequence[Union[str, DestructuredArg]],
         return_expr: Var | Any,
         return_expr: Var | Any,
+        rest: str | None = None,
+        explicit_return: bool = False,
         _var_type: GenericType = Callable,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
     ) -> ArgsFunctionOperation:
     ) -> ArgsFunctionOperation:
@@ -159,6 +209,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         Args:
         Args:
             args_names: The names of the arguments.
             args_names: The names of the arguments.
             return_expr: The return expression of the function.
             return_expr: The return expression of the function.
+            rest: The name of the rest argument.
+            explicit_return: Whether to use explicit return syntax.
             _var_data: Additional hooks and imports associated with the Var.
             _var_data: Additional hooks and imports associated with the Var.
 
 
         Returns:
         Returns:
@@ -168,8 +220,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             _js_expr="",
             _js_expr="",
             _var_type=_var_type,
             _var_type=_var_type,
             _var_data=_var_data,
             _var_data=_var_data,
-            _args_names=args_names,
+            _args=FunctionArgs(args=tuple(args_names), rest=rest),
             _return_expr=return_expr,
             _return_expr=return_expr,
+            _explicit_return=explicit_return,
         )
         )
 
 
 
 

+ 3 - 3
tests/units/components/base/test_script.py

@@ -62,14 +62,14 @@ def test_script_event_handler():
     )
     )
     render_dict = component.render()
     render_dict = component.render()
     assert (
     assert (
-        f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{  }}), ({{  }})))], args, ({{  }})))))}}'
+        f'onReady={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{  }}), ({{  }})))], args, ({{  }}))))}}'
         in render_dict["props"]
         in render_dict["props"]
     )
     )
     assert (
     assert (
-        f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{  }}), ({{  }})))], args, ({{  }})))))}}'
+        f'onLoad={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_load", ({{  }}), ({{  }})))], args, ({{  }}))))}}'
         in render_dict["props"]
         in render_dict["props"]
     )
     )
     assert (
     assert (
-        f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{  }}), ({{  }})))], args, ({{  }})))))}}'
+        f'onError={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_error", ({{  }}), ({{  }})))], args, ({{  }}))))}}'
         in render_dict["props"]
         in render_dict["props"]
     )
     )

+ 0 - 19
tests/units/components/datadisplay/test_code.py

@@ -11,22 +11,3 @@ def test_code_light_dark_theme(theme, expected):
     code_block = CodeBlock.create(theme=theme)
     code_block = CodeBlock.create(theme=theme)
 
 
     assert code_block.theme._js_expr == expected  # type: ignore
     assert code_block.theme._js_expr == expected  # type: ignore
-
-
-def generate_custom_code(language, expected_case):
-    return f"SyntaxHighlighter.registerLanguage('{language}', {expected_case})"
-
-
-@pytest.mark.parametrize(
-    "language, expected_case",
-    [
-        ("python", "python"),
-        ("firestore-security-rules", "firestoreSecurityRules"),
-        ("typescript", "typescript"),
-    ],
-)
-def test_get_custom_code(language, expected_case):
-    code_block = CodeBlock.create(language=language)
-    assert code_block._get_custom_code() == generate_custom_code(
-        language, expected_case
-    )

+ 0 - 0
tests/units/components/markdown/__init__.py


+ 190 - 0
tests/units/components/markdown/test_markdown.py

@@ -0,0 +1,190 @@
+from typing import Type
+
+import pytest
+
+from reflex.components.component import Component, memo
+from reflex.components.datadisplay.code import CodeBlock
+from reflex.components.datadisplay.shiki_code_block import ShikiHighLevelCodeBlock
+from reflex.components.markdown.markdown import Markdown, MarkdownComponentMap
+from reflex.components.radix.themes.layout.box import Box
+from reflex.components.radix.themes.typography.heading import Heading
+from reflex.vars.base import Var
+
+
+class CustomMarkdownComponent(Component, MarkdownComponentMap):
+    """A custom markdown component."""
+
+    tag = "CustomMarkdownComponent"
+    library = "custom"
+
+    @classmethod
+    def get_fn_args(cls) -> tuple[str, ...]:
+        """Return the function arguments.
+
+        Returns:
+            The function arguments.
+        """
+        return ("custom_node", "custom_children", "custom_props")
+
+    @classmethod
+    def get_fn_body(cls) -> Var:
+        """Return the function body.
+
+        Returns:
+            The function body.
+        """
+        return Var(_js_expr="{return custom_node + custom_children + custom_props}")
+
+
+def syntax_highlighter_memoized_component(codeblock: Type[Component]):
+    @memo
+    def code_block(code: str, language: str):
+        return Box.create(
+            codeblock.create(
+                code,
+                language=language,
+                class_name="code-block",
+                can_copy=True,
+            ),
+            class_name="relative mb-4",
+        )
+
+    def code_block_markdown(*children, **props):
+        return code_block(
+            code=children[0], language=props.pop("language", "plain"), **props
+        )
+
+    return code_block_markdown
+
+
+@pytest.mark.parametrize(
+    "fn_body, fn_args, explicit_return, expected",
+    [
+        (
+            None,
+            None,
+            False,
+            Var(_js_expr="(({node, children, ...props}) => undefined)"),
+        ),
+        ("return node", ("node",), True, Var(_js_expr="(({node}) => {return node})")),
+        (
+            "return node + children",
+            ("node", "children"),
+            True,
+            Var(_js_expr="(({node, children}) => {return node + children})"),
+        ),
+        (
+            "return node + props",
+            ("node", "...props"),
+            True,
+            Var(_js_expr="(({node, ...props}) => {return node + props})"),
+        ),
+        (
+            "return node + children + props",
+            ("node", "children", "...props"),
+            True,
+            Var(
+                _js_expr="(({node, children, ...props}) => {return node + children + props})"
+            ),
+        ),
+    ],
+)
+def test_create_map_fn_var(fn_body, fn_args, explicit_return, expected):
+    result = MarkdownComponentMap.create_map_fn_var(
+        fn_body=Var(_js_expr=fn_body, _var_type=str) if fn_body else None,
+        fn_args=fn_args,
+        explicit_return=explicit_return,
+    )
+    assert result._js_expr == expected._js_expr
+
+
+@pytest.mark.parametrize(
+    ("cls", "fn_body", "fn_args", "explicit_return", "expected"),
+    [
+        (
+            MarkdownComponentMap,
+            None,
+            None,
+            False,
+            Var(_js_expr="(({node, children, ...props}) => undefined)"),
+        ),
+        (
+            MarkdownComponentMap,
+            "return node",
+            ("node",),
+            True,
+            Var(_js_expr="(({node}) => {return node})"),
+        ),
+        (
+            CustomMarkdownComponent,
+            None,
+            None,
+            True,
+            Var(
+                _js_expr="(({custom_node, custom_children, custom_props}) => {return custom_node + custom_children + custom_props})"
+            ),
+        ),
+        (
+            CustomMarkdownComponent,
+            "return custom_node",
+            ("custom_node",),
+            True,
+            Var(_js_expr="(({custom_node}) => {return custom_node})"),
+        ),
+    ],
+)
+def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expected):
+    result = cls.create_map_fn_var(
+        fn_body=Var(_js_expr=fn_body, _var_type=int) if fn_body else None,
+        fn_args=fn_args,
+        explicit_return=explicit_return,
+    )
+    assert result._js_expr == expected._js_expr
+
+
+@pytest.mark.parametrize(
+    "key,component_map, expected",
+    [
+        (
+            "code",
+            {},
+            """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?<lang>.*)/); const _language = match ? match[1] : '';   if (_language) {     (async () => {       try {         const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`);         SyntaxHighlighter.registerLanguage(_language, module.default);       } catch (error) {         console.error(`Error importing language module for ${_language}:`, error);       }     })();   }  ;             return inline ? (                 <RadixThemesCode {...props}>{children}</RadixThemesCode>             ) : (                 <SyntaxHighlighter children={((Array.isArray(children)) ? children.join("\\n") : children)} css={({ ["marginTop"] : "1em", ["marginBottom"] : "1em" })} customStyle={({ ["marginTop"] : "1em", ["marginBottom"] : "1em" })} language={_language} style={((resolvedColorMode === "light") ? oneLight : oneDark)} wrapLongLines={true} {...props}/>             );         })""",
+        ),
+        (
+            "code",
+            {
+                "codeblock": lambda value, **props: ShikiHighLevelCodeBlock.create(
+                    value, **props
+                )
+            },
+            """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?<lang>.*)/); const _language = match ? match[1] : '';  ;             return inline ? (                 <RadixThemesCode {...props}>{children}</RadixThemesCode>             ) : (                 <RadixThemesBox css={({ ["pre"] : ({ ["margin"] : "0", ["padding"] : "24px", ["background"] : "transparent", ["overflow-x"] : "auto", ["border-radius"] : "6px" }) })} {...props}><ShikiCode code={((Array.isArray(children)) ? children.join("\\n") : children)} decorations={[]} language={_language} theme={((resolvedColorMode === "light") ? "one-light" : "one-dark-pro")} transformers={[]}/></RadixThemesBox>             );         })""",
+        ),
+        (
+            "h1",
+            {
+                "h1": lambda value: CustomMarkdownComponent.create(
+                    Heading.create(value, as_="h1", size="6", margin_y="0.5em")
+                )
+            },
+            """(({custom_node, custom_children, custom_props}) => (<CustomMarkdownComponent {...props}><RadixThemesHeading as={"h1"} css={({ ["marginTop"] : "0.5em", ["marginBottom"] : "0.5em" })} size={"6"}>{children}</RadixThemesHeading></CustomMarkdownComponent>))""",
+        ),
+        (
+            "code",
+            {"codeblock": syntax_highlighter_memoized_component(CodeBlock)},
+            """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?<lang>.*)/); const _language = match ? match[1] : '';   if (_language) {     (async () => {       try {         const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`);         SyntaxHighlighter.registerLanguage(_language, module.default);       } catch (error) {         console.error(`Error importing language module for ${_language}:`, error);       }     })();   }  ;             return inline ? (                 <RadixThemesCode {...props}>{children}</RadixThemesCode>             ) : (                 <CodeBlock code={((Array.isArray(children)) ? children.join("\\n") : children)} language={_language} {...props}/>             );         })""",
+        ),
+        (
+            "code",
+            {
+                "codeblock": syntax_highlighter_memoized_component(
+                    ShikiHighLevelCodeBlock
+                )
+            },
+            """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?<lang>.*)/); const _language = match ? match[1] : '';  ;             return inline ? (                 <RadixThemesCode {...props}>{children}</RadixThemesCode>             ) : (                 <CodeBlock code={((Array.isArray(children)) ? children.join("\\n") : children)} language={_language} {...props}/>             );         })""",
+        ),
+    ],
+)
+def test_markdown_format_component(key, component_map, expected):
+    markdown = Markdown.create("# header", component_map=component_map)
+    result = markdown.format_component_map()
+    assert str(result[key]) == expected

+ 2 - 2
tests/units/components/test_component.py

@@ -844,9 +844,9 @@ def test_component_event_trigger_arbitrary_args():
     comp = C1.create(on_foo=C1State.mock_handler)
     comp = C1.create(on_foo=C1State.mock_handler)
 
 
     assert comp.render()["props"][0] == (
     assert comp.render()["props"][0] == (
-        "onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents("
+        "onFoo={((__e, _alpha, _bravo, _charlie) => (addEvents("
         f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{  }})))], '
         f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{  }})))], '
-        "[__e, _alpha, _bravo, _charlie], ({  })))))}"
+        "[__e, _alpha, _bravo, _charlie], ({  }))))}"
     )
     )
 
 
 
 

+ 6 - 6
tests/units/test_event.py

@@ -222,16 +222,16 @@ def test_event_console_log():
     assert spec.handler.fn.__qualname__ == "_call_function"
     assert spec.handler.fn.__qualname__ == "_call_function"
     assert spec.args[0][0].equals(Var(_js_expr="function"))
     assert spec.args[0][0].equals(Var(_js_expr="function"))
     assert spec.args[0][1].equals(
     assert spec.args[0][1].equals(
-        Var('(() => ((console["log"]("message"))))', _var_type=Callable)
+        Var('(() => (console["log"]("message")))', _var_type=Callable)
     )
     )
     assert (
     assert (
         format.format_event(spec)
         format.format_event(spec)
-        == 'Event("_call_function", {function:(() => ((console["log"]("message"))))})'
+        == 'Event("_call_function", {function:(() => (console["log"]("message")))})'
     )
     )
     spec = event.console_log(Var(_js_expr="message"))
     spec = event.console_log(Var(_js_expr="message"))
     assert (
     assert (
         format.format_event(spec)
         format.format_event(spec)
-        == 'Event("_call_function", {function:(() => ((console["log"](message))))})'
+        == 'Event("_call_function", {function:(() => (console["log"](message)))})'
     )
     )
 
 
 
 
@@ -242,16 +242,16 @@ def test_event_window_alert():
     assert spec.handler.fn.__qualname__ == "_call_function"
     assert spec.handler.fn.__qualname__ == "_call_function"
     assert spec.args[0][0].equals(Var(_js_expr="function"))
     assert spec.args[0][0].equals(Var(_js_expr="function"))
     assert spec.args[0][1].equals(
     assert spec.args[0][1].equals(
-        Var('(() => ((window["alert"]("message"))))', _var_type=Callable)
+        Var('(() => (window["alert"]("message")))', _var_type=Callable)
     )
     )
     assert (
     assert (
         format.format_event(spec)
         format.format_event(spec)
-        == 'Event("_call_function", {function:(() => ((window["alert"]("message"))))})'
+        == 'Event("_call_function", {function:(() => (window["alert"]("message")))})'
     )
     )
     spec = event.window_alert(Var(_js_expr="message"))
     spec = event.window_alert(Var(_js_expr="message"))
     assert (
     assert (
         format.format_event(spec)
         format.format_event(spec)
-        == 'Event("_call_function", {function:(() => ((window["alert"](message))))})'
+        == 'Event("_call_function", {function:(() => (window["alert"](message)))})'
     )
     )
 
 
 
 

+ 24 - 4
tests/units/test_var.py

@@ -22,7 +22,11 @@ from reflex.vars.base import (
     var_operation,
     var_operation,
     var_operation_return,
     var_operation_return,
 )
 )
-from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
+from reflex.vars.function import (
+    ArgsFunctionOperation,
+    DestructuredArg,
+    FunctionStringVar,
+)
 from reflex.vars.number import LiteralBooleanVar, LiteralNumberVar, NumberVar
 from reflex.vars.number import LiteralBooleanVar, LiteralNumberVar, NumberVar
 from reflex.vars.object import LiteralObjectVar, ObjectVar
 from reflex.vars.object import LiteralObjectVar, ObjectVar
 from reflex.vars.sequence import (
 from reflex.vars.sequence import (
@@ -921,13 +925,13 @@ def test_function_var():
     )
     )
     assert (
     assert (
         str(manual_addition_func.call(1, 2))
         str(manual_addition_func.call(1, 2))
-        == '(((a, b) => (({ ["args"] : [a, b], ["result"] : a + b })))(1, 2))'
+        == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
     )
     )
 
 
     increment_func = addition_func(1)
     increment_func = addition_func(1)
     assert (
     assert (
         str(increment_func.call(2))
         str(increment_func.call(2))
-        == "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))"
+        == "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"
     )
     )
 
 
     create_hello_statement = ArgsFunctionOperation.create(
     create_hello_statement = ArgsFunctionOperation.create(
@@ -937,8 +941,24 @@ def test_function_var():
     last_name = LiteralStringVar.create("Universe")
     last_name = LiteralStringVar.create("Universe")
     assert (
     assert (
         str(create_hello_statement.call(f"{first_name} {last_name}"))
         str(create_hello_statement.call(f"{first_name} {last_name}"))
-        == '(((name) => (("Hello, "+name+"!")))("Steven Universe"))'
+        == '(((name) => ("Hello, "+name+"!"))("Steven Universe"))'
+    )
+
+    # Test with destructured arguments
+    destructured_func = ArgsFunctionOperation.create(
+        (DestructuredArg(fields=("a", "b")),),
+        Var(_js_expr="a + b"),
+    )
+    assert (
+        str(destructured_func.call({"a": 1, "b": 2}))
+        == '((({a, b}) => a + b)(({ ["a"] : 1, ["b"] : 2 })))'
+    )
+
+    # Test with explicit return
+    explicit_return_func = ArgsFunctionOperation.create(
+        ("a", "b"), Var(_js_expr="return a + b"), explicit_return=True
     )
     )
+    assert str(explicit_return_func.call(1, 2)) == "(((a, b) => {return a + b})(1, 2))"
 
 
 
 
 def test_var_operation():
 def test_var_operation():

+ 5 - 5
tests/units/utils/test_format.py

@@ -374,7 +374,7 @@ def test_format_match(
                 events=[EventSpec(handler=EventHandler(fn=mock_event))],
                 events=[EventSpec(handler=EventHandler(fn=mock_event))],
                 args_spec=lambda: [],
                 args_spec=lambda: [],
             ),
             ),
-            '((...args) => ((addEvents([(Event("mock_event", ({  }), ({  })))], args, ({  })))))',
+            '((...args) => (addEvents([(Event("mock_event", ({  }), ({  })))], args, ({  }))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -395,7 +395,7 @@ def test_format_match(
                 ],
                 ],
                 args_spec=lambda e: [e.target.value],
                 args_spec=lambda e: [e.target.value],
             ),
             ),
-            '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({  })))], [_e], ({  })))))',
+            '((_e) => (addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({  })))], [_e], ({  }))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -403,7 +403,7 @@ def test_format_match(
                 args_spec=lambda: [],
                 args_spec=lambda: [],
                 event_actions={"stopPropagation": True},
                 event_actions={"stopPropagation": True},
             ),
             ),
-            '((...args) => ((addEvents([(Event("mock_event", ({  }), ({  })))], args, ({ ["stopPropagation"] : true })))))',
+            '((...args) => (addEvents([(Event("mock_event", ({  }), ({  })))], args, ({ ["stopPropagation"] : true }))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -415,7 +415,7 @@ def test_format_match(
                 ],
                 ],
                 args_spec=lambda: [],
                 args_spec=lambda: [],
             ),
             ),
-            '((...args) => ((addEvents([(Event("mock_event", ({  }), ({ ["stopPropagation"] : true })))], args, ({  })))))',
+            '((...args) => (addEvents([(Event("mock_event", ({  }), ({ ["stopPropagation"] : true })))], args, ({  }))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -423,7 +423,7 @@ def test_format_match(
                 args_spec=lambda: [],
                 args_spec=lambda: [],
                 event_actions={"preventDefault": True},
                 event_actions={"preventDefault": True},
             ),
             ),
-            '((...args) => ((addEvents([(Event("mock_event", ({  }), ({  })))], args, ({ ["preventDefault"] : true })))))',
+            '((...args) => (addEvents([(Event("mock_event", ({  }), ({  })))], args, ({ ["preventDefault"] : true }))))',
         ),
         ),
         ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
         ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
         (Var(_js_expr="var", _var_type=int).guess_type(), "var"),
         (Var(_js_expr="var", _var_type=int).guess_type(), "var"),