markdown.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. """Markdown component."""
  2. from __future__ import annotations
  3. import textwrap
  4. from functools import lru_cache
  5. from hashlib import md5
  6. from typing import Any, Callable, Dict, Union
  7. from reflex.components.component import Component, CustomComponent
  8. from reflex.components.radix.themes.layout.list import (
  9. ListItem,
  10. OrderedList,
  11. UnorderedList,
  12. )
  13. from reflex.components.radix.themes.typography.heading import Heading
  14. from reflex.components.radix.themes.typography.link import Link
  15. from reflex.components.radix.themes.typography.text import Text
  16. from reflex.components.tags.tag import Tag
  17. from reflex.utils import types
  18. from reflex.utils.imports import ImportDict, ImportVar
  19. from reflex.vars.base import LiteralVar, Var
  20. from reflex.vars.function import ARRAY_ISARRAY
  21. from reflex.vars.number import ternary_operation
  # Raw JS placeholders spliced into the generated component-map functions.
  # _CHILDREN / _PROPS name the destructured arrow-function parameters,
  # _PROPS_IN_TAG spreads the remaining props onto the rendered element, and
  # _MOCK_ARG is an empty stand-in child used when a component factory is
  # called only to inspect the component it produces (imports, hooks, hash).
  _CHILDREN = Var(_js_expr="children", _var_type=str)
  _PROPS = Var(_js_expr="...props")
  _PROPS_IN_TAG = Var(_js_expr="{...props}")
  _MOCK_ARG = Var(_js_expr="", _var_type=str)

  # Remark plugins: each var is the JS identifier the plugin is imported as,
  # and the list is what gets handed to the `remark_plugins` prop.
  _REMARK_MATH = Var(_js_expr="remarkMath")
  _REMARK_GFM = Var(_js_expr="remarkGfm")
  _REMARK_UNWRAP_IMAGES = Var(_js_expr="remarkUnwrapImages")
  _REMARK_PLUGINS = LiteralVar.create(
      [
          _REMARK_MATH,
          _REMARK_GFM,
          _REMARK_UNWRAP_IMAGES,
      ]
  )

  # Rehype plugins, handled the same way for the `rehype_plugins` prop.
  _REHYPE_KATEX = Var(_js_expr="rehypeKatex")
  _REHYPE_RAW = Var(_js_expr="rehypeRaw")
  _REHYPE_PLUGINS = LiteralVar.create(
      [
          _REHYPE_KATEX,
          _REHYPE_RAW,
      ]
  )

  # Tags whose rendered component must NOT receive the `{...props}` spread.
  NO_PROPS_TAGS = ("ul", "ol", "li")
  38. # Component Mapping
  @lru_cache
  def get_base_component_map() -> dict[str, Callable]:
      """Build the default mapping from markdown tags to component factories.

      The result is memoized by ``lru_cache``, so repeated calls return the
      same dict object.

      Returns:
          The base component map.
      """
      # Imported inside the function, exactly as the original does
      # (presumably to avoid an import cycle -- confirm before hoisting).
      from reflex.components.datadisplay.code import CodeBlock
      from reflex.components.radix.themes.typography.code import Code

      def heading(as_tag: str, size: str) -> Callable:
          # A factory call per level gives every lambda its own tag/size binding.
          return lambda value: Heading.create(
              value, as_=as_tag, size=size, margin_y="0.5em"
          )

      # Headings come first so the key order matches h1..h6, then the rest.
      heading_sizes = (
          ("h1", "6"),
          ("h2", "5"),
          ("h3", "4"),
          ("h4", "3"),
          ("h5", "2"),
          ("h6", "1"),
      )
      base_map = {as_tag: heading(as_tag, size) for as_tag, size in heading_sizes}

      base_map.update(
          {
              "p": lambda value: Text.create(value, margin_y="1em"),
              "ul": lambda value: UnorderedList.create(value, margin_y="1em"),  # type: ignore
              "ol": lambda value: OrderedList.create(value, margin_y="1em"),  # type: ignore
              "li": lambda value: ListItem.create(value, margin_y="0.5em"),
              "a": lambda value: Link.create(value),
              "code": lambda value: Code.create(value),
              # Only the code block factory accepts extra props (e.g. language).
              "codeblock": lambda value, **props: CodeBlock.create(
                  value, margin_y="1em", wrap_long_lines=True, **props
              ),
          }
      )
      return base_map
  class Markdown(Component):
      """A markdown component."""

      library = "react-markdown@8.0.7"

      tag = "ReactMarkdown"

      is_default = True

      # The component map from a tag to a lambda that creates a component.
      component_map: Dict[str, Any] = {}

      # The hash of the component map, generated at create() time.
      component_map_hash: str = ""

      @classmethod
      def create(cls, *children, **props) -> Component:
          """Create a markdown component.

          Args:
              *children: The children of the component (exactly one str or Var
                  holding the markdown source).
              **props: The properties of the component. An optional
                  ``component_map`` prop overrides entries of the base map.

          Raises:
              ValueError: If the children are not valid.

          Returns:
              The markdown component.
          """
          if len(children) != 1 or not types._isinstance(children[0], Union[str, Var]):
              raise ValueError(
                  "Markdown component must have exactly one child containing the markdown source."
              )

          # Update the base component map with the custom component map.
          # Building a new dict keeps the cached base map untouched.
          component_map = {**get_base_component_map(), **props.pop("component_map", {})}

          # Get the markdown source.
          src = children[0]

          # Dedent literal sources so indented triple-quoted strings render as
          # intended; Var sources are passed through unchanged.
          if isinstance(src, str):
              src = textwrap.dedent(src)

          # Create the component.
          return super().create(
              src,
              component_map=component_map,
              component_map_hash=cls._component_map_hash(component_map),
              **props,
          )

      def _get_all_custom_components(
          self, seen: set[str] | None = None
      ) -> set[CustomComponent]:
          """Get all the custom components used by the component.

          Args:
              seen: The tags of the components that have already been seen.

          Returns:
              The set of custom components.
          """
          custom_components = super()._get_all_custom_components(seen=seen)

          # Each mapped factory is called with a mock child so the component it
          # produces can report its own custom components.
          for component in self.component_map.values():
              custom_components |= component(_MOCK_ARG)._get_all_custom_components(
                  seen=seen
              )

          return custom_components

      def add_imports(self) -> ImportDict | list[ImportDict]:
          """Add imports for the markdown component.

          Returns:
              The imports for the markdown component: the remark/rehype plugins
              and katex stylesheet, the imports of every mapped component, and
              the imports of the code block and inline code components.
          """
          from reflex.components.datadisplay.code import CodeBlock, Theme
          from reflex.components.radix.themes.typography.code import Code

          return [
              {
                  "": "katex/dist/katex.min.css",
                  "remark-math@5.1.1": ImportVar(
                      tag=_REMARK_MATH._js_expr, is_default=True
                  ),
                  "remark-gfm@3.0.1": ImportVar(
                      tag=_REMARK_GFM._js_expr, is_default=True
                  ),
                  "remark-unwrap-images@4.0.0": ImportVar(
                      tag=_REMARK_UNWRAP_IMAGES._js_expr, is_default=True
                  ),
                  "rehype-katex@6.0.3": ImportVar(
                      tag=_REHYPE_KATEX._js_expr, is_default=True
                  ),
                  "rehype-raw@6.1.1": ImportVar(
                      tag=_REHYPE_RAW._js_expr, is_default=True
                  ),
              },
              *[
                  component(_MOCK_ARG)._get_all_imports()  # type: ignore
                  for component in self.component_map.values()
              ],
              CodeBlock.create(theme=Theme.light)._get_imports(),
              Code.create()._get_imports(),
          ]

      def get_component(self, tag: str, **props) -> Component:
          """Get the component for a tag and props.

          Args:
              tag: The tag of the component.
              **props: The props of the component. A ``children`` prop, when
                  given, is emitted as a prop instead of as child content.

          Returns:
              The component.

          Raises:
              ValueError: If the tag is invalid.
          """
          # Check the tag is valid.
          if tag not in self.component_map:
              raise ValueError(f"No markdown component found for tag: {tag}.")

          special_props = [_PROPS_IN_TAG]
          children = [
              _CHILDREN
              if tag != "codeblock"
              # For codeblock, the mapping for some cases returns an array of elements. Let's join them into a string.
              else ternary_operation(
                  ARRAY_ISARRAY.call(_CHILDREN),  # type: ignore
                  _CHILDREN.to(list).join("\n"),
                  _CHILDREN,
              ).to(str)
          ]

          # For certain tags, the props from the markdown renderer are not actually valid for the component.
          if tag in NO_PROPS_TAGS:
              special_props = []

          # If the children are set as a prop, don't pass them as children.
          children_prop = props.pop("children", None)
          if children_prop is not None:
              special_props.append(Var(_js_expr=f"children={{{children_prop!s}}}"))
              children = []

          # Build the component through the mapped factory and attach the
          # special props collected above.
          return self.component_map[tag](*children, **props).set(
              special_props=special_props
          )

      def format_component(self, tag: str, **props) -> str:
          """Format a component for rendering in the component map.

          Args:
              tag: The tag of the component.
              **props: Extra props to pass to the component function.

          Returns:
              The formatted component, with all newlines removed.
          """
          return str(self.get_component(tag, **props)).replace("\n", "")

      def format_component_map(self) -> dict[str, Var]:
          """Format the component map for rendering.

          Returns:
              The formatted component map: one JS arrow function per tag, with
              the ``code`` entry replaced by a function that dispatches between
              inline code and code blocks.
          """
          components = {
              tag: Var(
                  _js_expr=f"(({{node, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => ({self.format_component(tag)}))"
              )
              for tag in self.component_map
          }

          # Separate out inline code and code blocks. Both branches are rendered
          # up front so the JS template below stays readable.
          inline_code = self.format_component("code")
          code_block = self.format_component(
              "codeblock", language=Var(_js_expr="language", _var_type=str)
          )
          components["code"] = Var(
              _js_expr=f"""(({{node, inline, className, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => {{
      const match = (className || '').match(/language-(?<lang>.*)/);
      const language = match ? match[1] : '';
      if (language) {{
      (async () => {{
        try {{
          const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{language}}`);
          SyntaxHighlighter.registerLanguage(language, module.default);
        }} catch (error) {{
          console.error(`Error importing language module for ${{language}}:`, error);
        }}
      }})();
    }}
      return inline ? (
          {inline_code}
      ) : (
          {code_block}
      );
        }})""".replace("\n", " ")
          )

          return components

      @staticmethod
      def _component_map_hash(component_map) -> str:
          """Compute a hash identifying a component map.

          Args:
              component_map: The mapping from tag to component factory.

          Returns:
              The md5 hex digest of the string form of the map obtained by
              calling every factory with a mock child.
          """
          inp = str(
              {tag: component(_MOCK_ARG) for tag, component in component_map.items()}
          ).encode()
          return md5(inp).hexdigest()

      def _get_component_map_name(self) -> str:
          """Get the name of the generated JS function returning the component map.

          Returns:
              The function name, suffixed with the component map hash.
          """
          return f"ComponentMap_{self.component_map_hash}"

      def _get_custom_code(self) -> str | None:
          """Generate the JS function that returns the formatted component map.

          Returns:
              The JS source of the function: the hooks of every mapped component
              followed by a return of the formatted component map.
          """
          # Collect hooks in an insertion-ordered dict rather than a set: a set of
          # strings iterates in a hash-dependent order that can change between
          # interpreter runs, which would make the emitted hook order (and thus
          # the generated code) non-deterministic. dict.fromkeys de-duplicates the
          # same way while keeping first-seen order.
          hooks: dict[str, None] = {}
          for _component in self.component_map.values():
              comp = _component(_MOCK_ARG)
              hooks.update(dict.fromkeys(comp._get_all_hooks_internal()))
              hooks.update(dict.fromkeys(comp._get_all_hooks()))
          formatted_hooks = "\n".join(hooks)
          return f"""
          function {self._get_component_map_name()} () {{
            {formatted_hooks}
            return (
              {str(LiteralVar.create(self.format_component_map()))}
            )
          }}
          """

      def _render(self) -> Tag:
          """Render the component tag.

          Returns:
              The tag with the remark/rehype plugins and the generated component
              map added as props, and the Python-only ``componentMap`` /
              ``componentMapHash`` props removed.
          """
          return (
              super()
              ._render()
              .add_props(
                  remark_plugins=_REMARK_PLUGINS,
                  rehype_plugins=_REHYPE_PLUGINS,
                  components=Var(_js_expr=f"{self._get_component_map_name()}()"),
              )
              .remove_props("componentMap", "componentMapHash")
          )