markdown.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. """Markdown component."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import textwrap
  5. from collections.abc import Callable, Sequence
  6. from functools import lru_cache
  7. from hashlib import md5
  8. from typing import Any
  9. from reflex.components.component import BaseComponent, Component, CustomComponent
  10. from reflex.components.tags.tag import Tag
  11. from reflex.utils.imports import ImportDict, ImportVar
  12. from reflex.vars.base import LiteralVar, Var, VarData
  13. from reflex.vars.function import ARRAY_ISARRAY, ArgsFunctionOperation, DestructuredArg
  14. from reflex.vars.number import ternary_operation
  15. # Special vars used in the component map.
  16. _CHILDREN = Var(_js_expr="children", _var_type=str)
  17. _PROPS = Var(_js_expr="...props")
  18. _PROPS_IN_TAG = Var(_js_expr="{...props}")
  19. _MOCK_ARG = Var(_js_expr="", _var_type=str)
  20. _LANGUAGE = Var(_js_expr="_language", _var_type=str)
  21. # Special remark plugins.
  22. _REMARK_MATH = Var(_js_expr="remarkMath")
  23. _REMARK_GFM = Var(_js_expr="remarkGfm")
  24. _REMARK_UNWRAP_IMAGES = Var(_js_expr="remarkUnwrapImages")
  25. _REMARK_PLUGINS = LiteralVar.create([_REMARK_MATH, _REMARK_GFM, _REMARK_UNWRAP_IMAGES])
  26. # Special rehype plugins.
  27. _REHYPE_KATEX = Var(_js_expr="rehypeKatex")
  28. _REHYPE_RAW = Var(_js_expr="rehypeRaw")
  29. _REHYPE_PLUGINS = LiteralVar.create([_REHYPE_KATEX, _REHYPE_RAW])
  30. # These tags do NOT get props passed to them
  31. NO_PROPS_TAGS = ("ul", "ol", "li")
  32. # Component Mapping
  33. @lru_cache
  34. def get_base_component_map() -> dict[str, Callable]:
  35. """Get the base component map.
  36. Returns:
  37. The base component map.
  38. """
  39. from reflex.components.datadisplay.code import CodeBlock
  40. from reflex.components.radix.themes.layout.list import (
  41. ListItem,
  42. OrderedList,
  43. UnorderedList,
  44. )
  45. from reflex.components.radix.themes.typography.code import Code
  46. from reflex.components.radix.themes.typography.heading import Heading
  47. from reflex.components.radix.themes.typography.link import Link
  48. from reflex.components.radix.themes.typography.text import Text
  49. return {
  50. "h1": lambda value: Heading.create(value, as_="h1", size="6", margin_y="0.5em"),
  51. "h2": lambda value: Heading.create(value, as_="h2", size="5", margin_y="0.5em"),
  52. "h3": lambda value: Heading.create(value, as_="h3", size="4", margin_y="0.5em"),
  53. "h4": lambda value: Heading.create(value, as_="h4", size="3", margin_y="0.5em"),
  54. "h5": lambda value: Heading.create(value, as_="h5", size="2", margin_y="0.5em"),
  55. "h6": lambda value: Heading.create(value, as_="h6", size="1", margin_y="0.5em"),
  56. "p": lambda value: Text.create(value, margin_y="1em"),
  57. "ul": lambda value: UnorderedList.create(value, margin_y="1em"),
  58. "ol": lambda value: OrderedList.create(value, margin_y="1em"),
  59. "li": lambda value: ListItem.create(value, margin_y="0.5em"),
  60. "a": lambda value: Link.create(value),
  61. "code": lambda value: Code.create(value),
  62. "codeblock": lambda value, **props: CodeBlock.create(
  63. value, margin_y="1em", wrap_long_lines=True, **props
  64. ),
  65. }
  66. @dataclasses.dataclass()
  67. class MarkdownComponentMap:
  68. """Mixin class for handling custom component maps in Markdown components."""
  69. _explicit_return: bool = dataclasses.field(default=False)
  70. @classmethod
  71. def get_component_map_custom_code(cls) -> Var:
  72. """Get the custom code for the component map.
  73. Returns:
  74. The custom code for the component map.
  75. """
  76. return Var("")
  77. @classmethod
  78. def create_map_fn_var(
  79. cls,
  80. fn_body: Var | None = None,
  81. fn_args: Sequence[str] | None = None,
  82. explicit_return: bool | None = None,
  83. var_data: VarData | None = None,
  84. ) -> Var:
  85. """Create a function Var for the component map.
  86. Args:
  87. fn_body: The formatted component as a string.
  88. fn_args: The function arguments.
  89. explicit_return: Whether to use explicit return syntax.
  90. var_data: The var data for the function.
  91. Returns:
  92. The function Var for the component map.
  93. """
  94. fn_args = fn_args or cls.get_fn_args()
  95. fn_body = fn_body if fn_body is not None else cls.get_fn_body()
  96. explicit_return = explicit_return or cls._explicit_return
  97. return ArgsFunctionOperation.create(
  98. args_names=(DestructuredArg(fields=tuple(fn_args)),),
  99. return_expr=fn_body,
  100. explicit_return=explicit_return,
  101. _var_data=var_data,
  102. )
  103. @classmethod
  104. def get_fn_args(cls) -> Sequence[str]:
  105. """Get the function arguments for the component map.
  106. Returns:
  107. The function arguments as a list of strings.
  108. """
  109. return ["node", _CHILDREN._js_expr, _PROPS._js_expr]
  110. @classmethod
  111. def get_fn_body(cls) -> Var:
  112. """Get the function body for the component map.
  113. Returns:
  114. The function body as a string.
  115. """
  116. return Var(_js_expr="undefined", _var_type=None)
  117. class Markdown(Component):
  118. """A markdown component."""
  119. library = "react-markdown@8.0.7"
  120. tag = "ReactMarkdown"
  121. is_default = True
  122. # The component map from a tag to a lambda that creates a component.
  123. component_map: dict[str, Any] = {}
  124. # The hash of the component map, generated at create() time.
  125. component_map_hash: str = ""
  126. @classmethod
  127. def create(cls, *children, **props) -> Component:
  128. """Create a markdown component.
  129. Args:
  130. *children: The children of the component.
  131. **props: The properties of the component.
  132. Raises:
  133. ValueError: If the children are not valid.
  134. Returns:
  135. The markdown component.
  136. """
  137. if len(children) != 1 or not isinstance(children[0], (str, Var)):
  138. raise ValueError(
  139. "Markdown component must have exactly one child containing the markdown source."
  140. )
  141. # Update the base component map with the custom component map.
  142. component_map = {**get_base_component_map(), **props.pop("component_map", {})}
  143. # Get the markdown source.
  144. src = children[0]
  145. # Dedent the source.
  146. if isinstance(src, str):
  147. src = textwrap.dedent(src)
  148. # Create the component.
  149. return super().create(
  150. src,
  151. component_map=component_map,
  152. component_map_hash=cls._component_map_hash(component_map),
  153. **props,
  154. )
  155. def _get_all_custom_components(
  156. self, seen: set[str] | None = None
  157. ) -> set[CustomComponent]:
  158. """Get all the custom components used by the component.
  159. Args:
  160. seen: The tags of the components that have already been seen.
  161. Returns:
  162. The set of custom components.
  163. """
  164. custom_components = super()._get_all_custom_components(seen=seen)
  165. # Get the custom components for each tag.
  166. for component in self.component_map.values():
  167. custom_components |= component(_MOCK_ARG)._get_all_custom_components(
  168. seen=seen
  169. )
  170. return custom_components
  171. def add_imports(self) -> ImportDict | list[ImportDict]:
  172. """Add imports for the markdown component.
  173. Returns:
  174. The imports for the markdown component.
  175. """
  176. return [
  177. {
  178. "": "katex/dist/katex.min.css",
  179. "remark-math@5.1.1": ImportVar(
  180. tag=_REMARK_MATH._js_expr, is_default=True
  181. ),
  182. "remark-gfm@3.0.1": ImportVar(
  183. tag=_REMARK_GFM._js_expr, is_default=True
  184. ),
  185. "remark-unwrap-images@4.0.0": ImportVar(
  186. tag=_REMARK_UNWRAP_IMAGES._js_expr, is_default=True
  187. ),
  188. "rehype-katex@6.0.3": ImportVar(
  189. tag=_REHYPE_KATEX._js_expr, is_default=True
  190. ),
  191. "rehype-raw@6.1.1": ImportVar(
  192. tag=_REHYPE_RAW._js_expr, is_default=True
  193. ),
  194. },
  195. *[
  196. component(_MOCK_ARG)._get_all_imports()
  197. for component in self.component_map.values()
  198. ],
  199. *(
  200. [inline_code_var_data.old_school_imports()]
  201. if (
  202. inline_code_var_data
  203. := self._get_inline_code_fn_var()._get_all_var_data()
  204. )
  205. is not None
  206. else []
  207. ),
  208. ]
  209. def _get_tag_map_fn_var(self, tag: str) -> Var:
  210. return self._get_map_fn_var_from_children(self.get_component(tag), tag)
  211. def format_component_map(self) -> dict[str, Var]:
  212. """Format the component map for rendering.
  213. Returns:
  214. The formatted component map.
  215. """
  216. components = {
  217. tag: self._get_tag_map_fn_var(tag)
  218. for tag in self.component_map
  219. if tag not in ("code", "codeblock")
  220. }
  221. # Separate out inline code and code blocks.
  222. components["code"] = self._get_inline_code_fn_var()
  223. return components
  224. def _get_inline_code_fn_var(self) -> Var:
  225. """Get the function variable for inline code.
  226. This function creates a Var that represents a function to handle
  227. both inline code and code blocks in markdown.
  228. Returns:
  229. The Var for inline code.
  230. """
  231. # Get any custom code from the codeblock and code components.
  232. custom_code_list = self._get_map_fn_custom_code_from_children(
  233. self.get_component("codeblock")
  234. )
  235. custom_code_list.extend(
  236. self._get_map_fn_custom_code_from_children(self.get_component("code"))
  237. )
  238. var_data = VarData.merge(
  239. *[
  240. code._get_all_var_data()
  241. for code in custom_code_list
  242. if isinstance(code, Var)
  243. ]
  244. )
  245. codeblock_custom_code = "\n".join(map(str, custom_code_list))
  246. # Format the code to handle inline and block code.
  247. formatted_code = f"""
  248. const match = (className || '').match(/language-(?<lang>.*)/);
  249. let {_LANGUAGE!s} = match ? match[1] : '';
  250. {codeblock_custom_code};
  251. return inline ? (
  252. {self.format_component("code")}
  253. ) : (
  254. {self.format_component("codeblock", language=_LANGUAGE)}
  255. );
  256. """.replace("\n", " ")
  257. return MarkdownComponentMap.create_map_fn_var(
  258. fn_args=(
  259. "node",
  260. "inline",
  261. "className",
  262. _CHILDREN._js_expr,
  263. _PROPS._js_expr,
  264. ),
  265. fn_body=Var(_js_expr=formatted_code),
  266. explicit_return=True,
  267. var_data=var_data,
  268. )
  269. def get_component(self, tag: str, **props) -> Component:
  270. """Get the component for a tag and props.
  271. Args:
  272. tag: The tag of the component.
  273. **props: The props of the component.
  274. Returns:
  275. The component.
  276. Raises:
  277. ValueError: If the tag is invalid.
  278. """
  279. # Check the tag is valid.
  280. if tag not in self.component_map:
  281. raise ValueError(f"No markdown component found for tag: {tag}.")
  282. special_props = [_PROPS_IN_TAG]
  283. children = [
  284. _CHILDREN
  285. if tag != "codeblock"
  286. # For codeblock, the mapping for some cases returns an array of elements. Let's join them into a string.
  287. else ternary_operation(
  288. ARRAY_ISARRAY.call(_CHILDREN),
  289. _CHILDREN.to(list).join("\n"),
  290. _CHILDREN,
  291. ).to(str)
  292. ]
  293. # For certain tags, the props from the markdown renderer are not actually valid for the component.
  294. if tag in NO_PROPS_TAGS:
  295. special_props = []
  296. # If the children are set as a prop, don't pass them as children.
  297. children_prop = props.pop("children", None)
  298. if children_prop is not None:
  299. special_props.append(Var(_js_expr=f"children={{{children_prop!s}}}"))
  300. children = []
  301. # Get the component.
  302. component = self.component_map[tag](*children, **props).set(
  303. special_props=special_props
  304. )
  305. return component
  306. def format_component(self, tag: str, **props) -> str:
  307. """Format a component for rendering in the component map.
  308. Args:
  309. tag: The tag of the component.
  310. **props: Extra props to pass to the component function.
  311. Returns:
  312. The formatted component.
  313. """
  314. return str(self.get_component(tag, **props)).replace("\n", "")
  315. def _get_map_fn_var_from_children(self, component: Component, tag: str) -> Var:
  316. """Create a function Var for the component map for the specified tag.
  317. Args:
  318. component: The component to check for custom code.
  319. tag: The tag of the component.
  320. Returns:
  321. The function Var for the component map.
  322. """
  323. formatted_component = Var(
  324. _js_expr=f"({self.format_component(tag)})", _var_type=str
  325. )
  326. if isinstance(component, MarkdownComponentMap):
  327. return component.create_map_fn_var(fn_body=formatted_component)
  328. # fallback to the default fn Var creation if the component is not a MarkdownComponentMap.
  329. return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component)
  330. def _get_map_fn_custom_code_from_children(
  331. self, component: BaseComponent
  332. ) -> list[str | Var]:
  333. """Recursively get markdown custom code from children components.
  334. Args:
  335. component: The component to check for custom code.
  336. Returns:
  337. A list of markdown custom code strings.
  338. """
  339. custom_code_list: list[str | Var] = []
  340. if isinstance(component, MarkdownComponentMap):
  341. custom_code_list.append(component.get_component_map_custom_code())
  342. # If the component is a custom component(rx.memo), obtain the underlining
  343. # component and get the custom code from the children.
  344. if isinstance(component, CustomComponent):
  345. custom_code_list.extend(
  346. self._get_map_fn_custom_code_from_children(
  347. component.component_fn(*component.get_prop_vars())
  348. )
  349. )
  350. elif isinstance(component, Component):
  351. for child in component.children:
  352. custom_code_list.extend(
  353. self._get_map_fn_custom_code_from_children(child)
  354. )
  355. return custom_code_list
  356. @staticmethod
  357. def _component_map_hash(component_map: dict) -> str:
  358. inp = str(
  359. {tag: component(_MOCK_ARG) for tag, component in component_map.items()}
  360. ).encode()
  361. return md5(inp).hexdigest()
  362. def _get_component_map_name(self) -> str:
  363. return f"ComponentMap_{self.component_map_hash}"
  364. def _get_custom_code(self) -> str | None:
  365. hooks = {}
  366. from reflex.compiler.templates import MACROS
  367. for _component in self.component_map.values():
  368. comp = _component(_MOCK_ARG)
  369. hooks.update(comp._get_all_hooks())
  370. formatted_hooks = MACROS.module.renderHooks(hooks) # pyright: ignore [reportAttributeAccessIssue]
  371. return f"""
  372. function {self._get_component_map_name()} () {{
  373. {formatted_hooks}
  374. return (
  375. {LiteralVar.create(self.format_component_map())!s}
  376. )
  377. }}
  378. """
  379. def _render(self) -> Tag:
  380. tag = (
  381. super()
  382. ._render()
  383. .add_props(
  384. remark_plugins=_REMARK_PLUGINS,
  385. rehype_plugins=_REHYPE_PLUGINS,
  386. components=Var(_js_expr=f"{self._get_component_map_name()}()"),
  387. )
  388. .remove_props("componentMap", "componentMapHash")
  389. )
  390. return tag