markdown.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. """Markdown component."""
  2. from __future__ import annotations
  3. import textwrap
  4. from typing import Any, Callable, Dict, Union
  5. from reflex.compiler import utils
  6. from reflex.components.component import Component
  7. from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList
  8. from reflex.components.navigation import Link
  9. from reflex.components.tags.tag import Tag
  10. from reflex.components.typography.heading import Heading
  11. from reflex.components.typography.text import Text
  12. from reflex.style import Style
  13. from reflex.utils import console, imports, types
  14. from reflex.vars import ImportVar, Var
  15. # Special vars used in the component map.
  16. _CHILDREN = Var.create_safe("children", is_local=False)
  17. _PROPS = Var.create_safe("...props", is_local=False)
  18. # Special remark plugins.
  19. _REMARK_MATH = Var.create_safe("remarkMath", is_local=False)
  20. _REMARK_GFM = Var.create_safe("remarkGfm", is_local=False)
  21. _REMARK_PLUGINS = Var.create_safe([_REMARK_MATH, _REMARK_GFM])
  22. # Special rehype plugins.
  23. _REHYPE_KATEX = Var.create_safe("rehypeKatex", is_local=False)
  24. _REHYPE_RAW = Var.create_safe("rehypeRaw", is_local=False)
  25. _REHYPE_PLUGINS = Var.create_safe([_REHYPE_KATEX, _REHYPE_RAW])
  26. # Component Mapping
  27. def get_base_component_map() -> dict[str, Callable]:
  28. """Get the base component map.
  29. Returns:
  30. The base component map.
  31. """
  32. from reflex.components.datadisplay.code import Code, CodeBlock
  33. return {
  34. "h1": lambda value: Heading.create(
  35. value, as_="h1", size="2xl", margin_y="0.5em"
  36. ),
  37. "h2": lambda value: Heading.create(
  38. value, as_="h2", size="xl", margin_y="0.5em"
  39. ),
  40. "h3": lambda value: Heading.create(
  41. value, as_="h3", size="lg", margin_y="0.5em"
  42. ),
  43. "h4": lambda value: Heading.create(
  44. value, as_="h4", size="md", margin_y="0.5em"
  45. ),
  46. "h5": lambda value: Heading.create(
  47. value, as_="h5", size="sm", margin_y="0.5em"
  48. ),
  49. "h6": lambda value: Heading.create(
  50. value, as_="h6", size="xs", margin_y="0.5em"
  51. ),
  52. "p": lambda value: Text.create(value, margin_y="1em"),
  53. "ul": lambda value: UnorderedList.create(value, margin_y="1em"), # type: ignore
  54. "ol": lambda value: OrderedList.create(value, margin_y="1em"), # type: ignore
  55. "li": lambda value: ListItem.create(value),
  56. "a": lambda value: Link.create(value),
  57. "code": lambda value: Code.create(value),
  58. "codeblock": lambda *_, **props: CodeBlock.create(
  59. theme="light", margin_y="1em", **props
  60. ),
  61. }
  62. class Markdown(Component):
  63. """A markdown component."""
  64. library = "react-markdown@8.0.7"
  65. tag = "ReactMarkdown"
  66. is_default = True
  67. # The component map from a tag to a lambda that creates a component.
  68. component_map: Dict[str, Any] = {}
  69. # Custom styles for the markdown (deprecated in v0.2.9).
  70. custom_styles: Dict[str, Any] = {}
  71. @classmethod
  72. def create(cls, *children, **props) -> Component:
  73. """Create a markdown component.
  74. Args:
  75. *children: The children of the component.
  76. **props: The properties of the component.
  77. Returns:
  78. The markdown component.
  79. """
  80. assert len(children) == 1 and types._isinstance(
  81. children[0], Union[str, Var]
  82. ), "Markdown component must have exactly one child containing the markdown source."
  83. # Custom styles are deprecated.
  84. if "custom_styles" in props:
  85. console.deprecate(
  86. "rx.markdown custom_styles",
  87. "Use the component_map prop instead.",
  88. "0.2.9",
  89. "0.3.1",
  90. )
  91. # Update the base component map with the custom component map.
  92. component_map = {**get_base_component_map(), **props.pop("component_map", {})}
  93. # Get the markdown source.
  94. src = children[0]
  95. # Dedent the source.
  96. if isinstance(src, str):
  97. src = textwrap.dedent(src)
  98. # Create the component.
  99. return super().create(src, component_map=component_map, **props)
  100. def _get_imports(self) -> imports.ImportDict:
  101. # Import here to avoid circular imports.
  102. from reflex.components.datadisplay.code import Code, CodeBlock
  103. imports = super()._get_imports()
  104. # Special markdown imports.
  105. imports.update(
  106. {
  107. "": {ImportVar(tag="katex/dist/katex.min.css")},
  108. "remark-math@5.1.1": {
  109. ImportVar(tag=_REMARK_MATH.name, is_default=True)
  110. },
  111. "remark-gfm@3.0.1": {ImportVar(tag=_REMARK_GFM.name, is_default=True)},
  112. "rehype-katex@6.0.3": {
  113. ImportVar(tag=_REHYPE_KATEX.name, is_default=True)
  114. },
  115. "rehype-raw@6.1.1": {ImportVar(tag=_REHYPE_RAW.name, is_default=True)},
  116. }
  117. )
  118. # Get the imports for each component.
  119. for component in self.component_map.values():
  120. imports = utils.merge_imports(
  121. imports, component(Var.create("")).get_imports()
  122. )
  123. # Get the imports for the code components.
  124. imports = utils.merge_imports(
  125. imports, CodeBlock.create(theme="light")._get_imports()
  126. )
  127. imports = utils.merge_imports(imports, Code.create()._get_imports())
  128. return imports
  129. def get_component(self, tag: str, **props) -> Component:
  130. """Get the component for a tag and props.
  131. Args:
  132. tag: The tag of the component.
  133. **props: The props of the component.
  134. Returns:
  135. The component.
  136. Raises:
  137. ValueError: If the tag is invalid.
  138. """
  139. # Check the tag is valid.
  140. if tag not in self.component_map:
  141. raise ValueError(f"No markdown component found for tag: {tag}.")
  142. special_props = {_PROPS}
  143. children = [_CHILDREN]
  144. # If the children are set as a prop, don't pass them as children.
  145. children_prop = props.pop("children", None)
  146. if children_prop is not None:
  147. special_props.add(Var.create_safe(f"children={str(children_prop)}"))
  148. children = []
  149. # Get the component.
  150. component = self.component_map[tag](*children, **props).set(
  151. special_props=special_props
  152. )
  153. component._add_style(Style(self.custom_styles.get(tag, {})))
  154. return component
  155. def format_component(self, tag: str, **props) -> str:
  156. """Format a component for rendering in the component map.
  157. Args:
  158. tag: The tag of the component.
  159. **props: Extra props to pass to the component function.
  160. Returns:
  161. The formatted component.
  162. """
  163. return str(self.get_component(tag, **props)).replace("\n", " ")
  164. def format_component_map(self) -> dict[str, str]:
  165. """Format the component map for rendering.
  166. Returns:
  167. The formatted component map.
  168. """
  169. components = {
  170. tag: f"{{({{{_CHILDREN.name}, {_PROPS.name}}}) => {self.format_component(tag)}}}"
  171. for tag in self.component_map
  172. }
  173. # Separate out inline code and code blocks.
  174. components[
  175. "code"
  176. ] = f"""{{({{inline, className, {_CHILDREN.name}, {_PROPS.name}}}) => {{
  177. const match = (className || '').match(/language-(?<lang>.*)/);
  178. const language = match ? match[1] : '';
  179. return inline ? (
  180. {self.format_component("code")}
  181. ) : (
  182. {self.format_component("codeblock", language=Var.create_safe("language", is_local=False), children=Var.create_safe("String(children)", is_local=False))}
  183. );
  184. }}}}""".replace(
  185. "\n", " "
  186. )
  187. return components
  188. def _render(self) -> Tag:
  189. return (
  190. super()
  191. ._render()
  192. .add_props(
  193. components=self.format_component_map(),
  194. remark_plugins=_REMARK_PLUGINS,
  195. rehype_plugins=_REHYPE_PLUGINS,
  196. )
  197. .remove_props("componentMap")
  198. )