Browse Source

[ENG-3892]Shiki codeblock support decorations (#4234)

* Shiki codeblock support decorations

* add decorations to useEffect

* fix pyright

* validate decorations dict

* Fix exception message plus unit tests

* possible test fix

* fix pyright

* possible tests fix

* cast decorations before creating codeblock

* `plain` is not a valid theme

* pyi fix

* address PR comment
Elijah Ahianyo 6 months ago
parent
commit
a968231750

+ 16 - 11
reflex/.templates/web/components/shiki/code.js

@@ -1,26 +1,31 @@
 import { useEffect, useState } from "react"
 import { useEffect, useState } from "react"
 import { codeToHtml} from "shiki"
 import { codeToHtml} from "shiki"
 
 
-export function Code ({code, theme, language, transformers, ...divProps}) {
+/**
+ * Code component that uses Shiki to convert code to HTML and render it.
+ *
+ * @param code - The code to be highlighted.
+ * @param theme - The theme to be used for highlighting.
+ * @param language - The language of the code.
+ * @param transformers - The transformers to be applied to the code.
+ * @param decorations - The decorations to be applied to the code.
+ * @param divProps - Additional properties to be passed to the div element.
+ * @returns The rendered code block.
+ */
+export function Code ({code, theme, language, transformers, decorations, ...divProps}) {
     const [codeResult, setCodeResult] = useState("")
     const [codeResult, setCodeResult] = useState("")
     useEffect(() => {
     useEffect(() => {
         async function fetchCode() {
         async function fetchCode() {
-          let final_code;
-
-          if (Array.isArray(code)) {
-            final_code = code[0];
-          } else {
-            final_code = code;
-          }
-          const result = await codeToHtml(final_code, {
+          const result = await codeToHtml(code, {
             lang: language,
             lang: language,
             theme,
             theme,
-            transformers
+            transformers,
+            decorations
           });
           });
           setCodeResult(result);
           setCodeResult(result);
         }
         }
         fetchCode();
         fetchCode();
-      }, [code, language, theme, transformers]
+      }, [code, language, theme, transformers, decorations]
 
 
     )
     )
     return (
     return (

+ 36 - 3
reflex/components/datadisplay/shiki_code_block.py

@@ -12,6 +12,7 @@ from reflex.components.core.colors import color
 from reflex.components.core.cond import color_mode_cond
 from reflex.components.core.cond import color_mode_cond
 from reflex.components.el.elements.forms import Button
 from reflex.components.el.elements.forms import Button
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
+from reflex.components.props import NoExtrasAllowedProps
 from reflex.components.radix.themes.layout.box import Box
 from reflex.components.radix.themes.layout.box import Box
 from reflex.event import call_script, set_clipboard
 from reflex.event import call_script, set_clipboard
 from reflex.style import Style
 from reflex.style import Style
@@ -253,6 +254,7 @@ LiteralCodeLanguage = Literal[
     "pascal",
     "pascal",
     "perl",
     "perl",
     "php",
     "php",
+    "plain",
     "plsql",
     "plsql",
     "po",
     "po",
     "postcss",
     "postcss",
@@ -369,10 +371,11 @@ LiteralCodeTheme = Literal[
     "nord",
     "nord",
     "one-dark-pro",
     "one-dark-pro",
     "one-light",
     "one-light",
-    "plain",
     "plastic",
     "plastic",
     "poimandres",
     "poimandres",
     "red",
     "red",
+    # rose-pine themes dont work with the current version of shikijs transformers
+    # https://github.com/shikijs/shiki/issues/730
     "rose-pine",
     "rose-pine",
     "rose-pine-dawn",
     "rose-pine-dawn",
     "rose-pine-moon",
     "rose-pine-moon",
@@ -390,6 +393,23 @@ LiteralCodeTheme = Literal[
 ]
 ]
 
 
 
 
+class Position(NoExtrasAllowedProps):
+    """Position of the decoration."""
+
+    line: int
+    character: int
+
+
+class ShikiDecorations(NoExtrasAllowedProps):
+    """Decorations for the code block."""
+
+    start: Union[int, Position]
+    end: Union[int, Position]
+    tag_name: str = "span"
+    properties: dict[str, Any] = {}
+    always_wrap: bool = False
+
+
 class ShikiBaseTransformers(Base):
 class ShikiBaseTransformers(Base):
     """Base for creating transformers."""
     """Base for creating transformers."""
 
 
@@ -537,6 +557,9 @@ class ShikiCodeBlock(Component):
         []
         []
     )
     )
 
 
+    # The decorations to use for the syntax highlighter.
+    decorations: Var[list[ShikiDecorations]] = Var.create([])
+
     @classmethod
     @classmethod
     def create(
     def create(
         cls,
         cls,
@@ -555,6 +578,7 @@ class ShikiCodeBlock(Component):
         # Separate props for the code block and the wrapper
         # Separate props for the code block and the wrapper
         code_block_props = {}
         code_block_props = {}
         code_wrapper_props = {}
         code_wrapper_props = {}
+        decorations = props.pop("decorations", [])
 
 
         class_props = cls.get_props()
         class_props = cls.get_props()
 
 
@@ -564,6 +588,15 @@ class ShikiCodeBlock(Component):
                 value
                 value
             )
             )
 
 
+        # cast decorations into ShikiDecorations.
+        decorations = [
+            ShikiDecorations(**decoration)
+            if not isinstance(decoration, ShikiDecorations)
+            else decoration
+            for decoration in decorations
+        ]
+        code_block_props["decorations"] = decorations
+
         code_block_props["code"] = children[0]
         code_block_props["code"] = children[0]
         code_block = super().create(**code_block_props)
         code_block = super().create(**code_block_props)
 
 
@@ -676,10 +709,10 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
     show_line_numbers: Var[bool]
     show_line_numbers: Var[bool]
 
 
     # Whether a copy button should appear.
     # Whether a copy button should appear.
-    can_copy: Var[bool] = Var.create(False)
+    can_copy: bool = False
 
 
     # copy_button: A custom copy button to override the default one.
     # copy_button: A custom copy button to override the default one.
-    copy_button: Var[Optional[Union[Component, bool]]] = Var.create(None)
+    copy_button: Optional[Union[Component, bool]] = None
 
 
     @classmethod
     @classmethod
     def create(
     def create(

+ 35 - 15
reflex/components/datadisplay/shiki_code_block.pyi

@@ -7,6 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
+from reflex.components.props import NoExtrasAllowedProps
 from reflex.event import EventType
 from reflex.event import EventType
 from reflex.style import Style
 from reflex.style import Style
 from reflex.vars.base import Var
 from reflex.vars.base import Var
@@ -192,6 +193,7 @@ LiteralCodeLanguage = Literal[
     "pascal",
     "pascal",
     "perl",
     "perl",
     "php",
     "php",
+    "plain",
     "plsql",
     "plsql",
     "po",
     "po",
     "postcss",
     "postcss",
@@ -308,7 +310,6 @@ LiteralCodeTheme = Literal[
     "nord",
     "nord",
     "one-dark-pro",
     "one-dark-pro",
     "one-light",
     "one-light",
-    "plain",
     "plastic",
     "plastic",
     "poimandres",
     "poimandres",
     "red",
     "red",
@@ -328,6 +329,17 @@ LiteralCodeTheme = Literal[
     "vitesse-light",
     "vitesse-light",
 ]
 ]
 
 
+class Position(NoExtrasAllowedProps):
+    line: int
+    character: int
+
+class ShikiDecorations(NoExtrasAllowedProps):
+    start: Union[int, Position]
+    end: Union[int, Position]
+    tag_name: str
+    properties: dict[str, Any]
+    always_wrap: bool
+
 class ShikiBaseTransformers(Base):
 class ShikiBaseTransformers(Base):
     library: str
     library: str
     fns: list[FunctionStringVar]
     fns: list[FunctionStringVar]
@@ -479,6 +491,7 @@ class ShikiCodeBlock(Component):
                     "pascal",
                     "pascal",
                     "perl",
                     "perl",
                     "php",
                     "php",
+                    "plain",
                     "plsql",
                     "plsql",
                     "po",
                     "po",
                     "postcss",
                     "postcss",
@@ -694,6 +707,7 @@ class ShikiCodeBlock(Component):
                         "pascal",
                         "pascal",
                         "perl",
                         "perl",
                         "php",
                         "php",
+                        "plain",
                         "plsql",
                         "plsql",
                         "po",
                         "po",
                         "postcss",
                         "postcss",
@@ -815,7 +829,6 @@ class ShikiCodeBlock(Component):
                     "nord",
                     "nord",
                     "one-dark-pro",
                     "one-dark-pro",
                     "one-light",
                     "one-light",
-                    "plain",
                     "plastic",
                     "plastic",
                     "poimandres",
                     "poimandres",
                     "red",
                     "red",
@@ -870,7 +883,6 @@ class ShikiCodeBlock(Component):
                         "nord",
                         "nord",
                         "one-dark-pro",
                         "one-dark-pro",
                         "one-light",
                         "one-light",
-                        "plain",
                         "plastic",
                         "plastic",
                         "poimandres",
                         "poimandres",
                         "red",
                         "red",
@@ -906,6 +918,9 @@ class ShikiCodeBlock(Component):
                 list[Union[ShikiBaseTransformers, dict[str, Any]]],
                 list[Union[ShikiBaseTransformers, dict[str, Any]]],
             ]
             ]
         ] = None,
         ] = None,
+        decorations: Optional[
+            Union[Var[list[ShikiDecorations]], list[ShikiDecorations]]
+        ] = None,
         style: Optional[Style] = None,
         style: Optional[Style] = None,
         key: Optional[Any] = None,
         key: Optional[Any] = None,
         id: Optional[Any] = None,
         id: Optional[Any] = None,
@@ -938,6 +953,7 @@ class ShikiCodeBlock(Component):
             themes: The set of themes to use for different modes.
             themes: The set of themes to use for different modes.
             code: The code to display.
             code: The code to display.
             transformers: The transformers to use for the syntax highlighter.
             transformers: The transformers to use for the syntax highlighter.
+            decorations: The decorations to use for the syntax highlighter.
             style: The style of the component.
             style: The style of the component.
             key: A unique key for the component.
             key: A unique key for the component.
             id: The id for the component.
             id: The id for the component.
@@ -965,10 +981,8 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
         *children,
         *children,
         use_transformers: Optional[Union[Var[bool], bool]] = None,
         use_transformers: Optional[Union[Var[bool], bool]] = None,
         show_line_numbers: Optional[Union[Var[bool], bool]] = None,
         show_line_numbers: Optional[Union[Var[bool], bool]] = None,
-        can_copy: Optional[Union[Var[bool], bool]] = None,
-        copy_button: Optional[
-            Union[Component, Var[Optional[Union[Component, bool]]], bool]
-        ] = None,
+        can_copy: Optional[bool] = None,
+        copy_button: Optional[Union[Component, bool]] = None,
         language: Optional[
         language: Optional[
             Union[
             Union[
                 Literal[
                 Literal[
@@ -1104,6 +1118,7 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
                     "pascal",
                     "pascal",
                     "perl",
                     "perl",
                     "php",
                     "php",
+                    "plain",
                     "plsql",
                     "plsql",
                     "po",
                     "po",
                     "postcss",
                     "postcss",
@@ -1319,6 +1334,7 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
                         "pascal",
                         "pascal",
                         "perl",
                         "perl",
                         "php",
                         "php",
+                        "plain",
                         "plsql",
                         "plsql",
                         "po",
                         "po",
                         "postcss",
                         "postcss",
@@ -1440,7 +1456,6 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
                     "nord",
                     "nord",
                     "one-dark-pro",
                     "one-dark-pro",
                     "one-light",
                     "one-light",
-                    "plain",
                     "plastic",
                     "plastic",
                     "poimandres",
                     "poimandres",
                     "red",
                     "red",
@@ -1495,7 +1510,6 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
                         "nord",
                         "nord",
                         "one-dark-pro",
                         "one-dark-pro",
                         "one-light",
                         "one-light",
-                        "plain",
                         "plastic",
                         "plastic",
                         "poimandres",
                         "poimandres",
                         "red",
                         "red",
@@ -1531,6 +1545,9 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
                 list[Union[ShikiBaseTransformers, dict[str, Any]]],
                 list[Union[ShikiBaseTransformers, dict[str, Any]]],
             ]
             ]
         ] = None,
         ] = None,
+        decorations: Optional[
+            Union[Var[list[ShikiDecorations]], list[ShikiDecorations]]
+        ] = None,
         style: Optional[Style] = None,
         style: Optional[Style] = None,
         key: Optional[Any] = None,
         key: Optional[Any] = None,
         id: Optional[Any] = None,
         id: Optional[Any] = None,
@@ -1567,6 +1584,7 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
             themes: The set of themes to use for different modes.
             themes: The set of themes to use for different modes.
             code: The code to display.
             code: The code to display.
             transformers: The transformers to use for the syntax highlighter.
             transformers: The transformers to use for the syntax highlighter.
+            decorations: The decorations to use for the syntax highlighter.
             style: The style of the component.
             style: The style of the component.
             key: A unique key for the component.
             key: A unique key for the component.
             id: The id for the component.
             id: The id for the component.
@@ -1593,10 +1611,8 @@ class CodeblockNamespace(ComponentNamespace):
         *children,
         *children,
         use_transformers: Optional[Union[Var[bool], bool]] = None,
         use_transformers: Optional[Union[Var[bool], bool]] = None,
         show_line_numbers: Optional[Union[Var[bool], bool]] = None,
         show_line_numbers: Optional[Union[Var[bool], bool]] = None,
-        can_copy: Optional[Union[Var[bool], bool]] = None,
-        copy_button: Optional[
-            Union[Component, Var[Optional[Union[Component, bool]]], bool]
-        ] = None,
+        can_copy: Optional[bool] = None,
+        copy_button: Optional[Union[Component, bool]] = None,
         language: Optional[
         language: Optional[
             Union[
             Union[
                 Literal[
                 Literal[
@@ -1732,6 +1748,7 @@ class CodeblockNamespace(ComponentNamespace):
                     "pascal",
                     "pascal",
                     "perl",
                     "perl",
                     "php",
                     "php",
+                    "plain",
                     "plsql",
                     "plsql",
                     "po",
                     "po",
                     "postcss",
                     "postcss",
@@ -1947,6 +1964,7 @@ class CodeblockNamespace(ComponentNamespace):
                         "pascal",
                         "pascal",
                         "perl",
                         "perl",
                         "php",
                         "php",
+                        "plain",
                         "plsql",
                         "plsql",
                         "po",
                         "po",
                         "postcss",
                         "postcss",
@@ -2068,7 +2086,6 @@ class CodeblockNamespace(ComponentNamespace):
                     "nord",
                     "nord",
                     "one-dark-pro",
                     "one-dark-pro",
                     "one-light",
                     "one-light",
-                    "plain",
                     "plastic",
                     "plastic",
                     "poimandres",
                     "poimandres",
                     "red",
                     "red",
@@ -2123,7 +2140,6 @@ class CodeblockNamespace(ComponentNamespace):
                         "nord",
                         "nord",
                         "one-dark-pro",
                         "one-dark-pro",
                         "one-light",
                         "one-light",
-                        "plain",
                         "plastic",
                         "plastic",
                         "poimandres",
                         "poimandres",
                         "red",
                         "red",
@@ -2159,6 +2175,9 @@ class CodeblockNamespace(ComponentNamespace):
                 list[Union[ShikiBaseTransformers, dict[str, Any]]],
                 list[Union[ShikiBaseTransformers, dict[str, Any]]],
             ]
             ]
         ] = None,
         ] = None,
+        decorations: Optional[
+            Union[Var[list[ShikiDecorations]], list[ShikiDecorations]]
+        ] = None,
         style: Optional[Style] = None,
         style: Optional[Style] = None,
         key: Optional[Any] = None,
         key: Optional[Any] = None,
         id: Optional[Any] = None,
         id: Optional[Any] = None,
@@ -2195,6 +2214,7 @@ class CodeblockNamespace(ComponentNamespace):
             themes: The set of themes to use for different modes.
             themes: The set of themes to use for different modes.
             code: The code to display.
             code: The code to display.
             transformers: The transformers to use for the syntax highlighter.
             transformers: The transformers to use for the syntax highlighter.
+            decorations: The decorations to use for the syntax highlighter.
             style: The style of the component.
             style: The style of the component.
             key: A unique key for the component.
             key: A unique key for the component.
             id: The id for the component.
             id: The id for the component.

+ 34 - 0
reflex/components/props.py

@@ -2,8 +2,11 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+from pydantic import ValidationError
+
 from reflex.base import Base
 from reflex.base import Base
 from reflex.utils import format
 from reflex.utils import format
+from reflex.utils.exceptions import InvalidPropValueError
 from reflex.vars.object import LiteralObjectVar
 from reflex.vars.object import LiteralObjectVar
 
 
 
 
@@ -40,3 +43,34 @@ class PropsBase(Base):
             format.to_camel_case(key): value
             format.to_camel_case(key): value
             for key, value in super().dict(*args, **kwargs).items()
             for key, value in super().dict(*args, **kwargs).items()
         }
         }
+
+
+class NoExtrasAllowedProps(Base):
+    """A class that holds props to be passed or applied to a component with no extra props allowed."""
+
+    def __init__(self, component_name=None, **kwargs):
+        """Initialize the props.
+
+        Args:
+            component_name: The custom name of the component.
+            kwargs: Kwargs to initialize the props.
+
+        Raises:
+            InvalidPropValueError: If invalid props are passed on instantiation.
+        """
+        component_name = component_name or type(self).__name__
+        try:
+            super().__init__(**kwargs)
+        except ValidationError as e:
+            invalid_fields = ", ".join([error["loc"][0] for error in e.errors()])  # type: ignore
+            supported_props_str = ", ".join(f'"{field}"' for field in self.get_fields())
+            raise InvalidPropValueError(
+                f"Invalid prop(s) {invalid_fields} for {component_name!r}. Supported props are {supported_props_str}"
+            ) from None
+
+    class Config:
+        """Pydantic config."""
+
+        arbitrary_types_allowed = True
+        use_enum_values = True
+        extra = "forbid"

+ 3 - 30
reflex/components/sonner/toast.py

@@ -4,12 +4,10 @@ from __future__ import annotations
 
 
 from typing import Any, ClassVar, Literal, Optional, Union
 from typing import Any, ClassVar, Literal, Optional, Union
 
 
-from pydantic import ValidationError
-
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
-from reflex.components.props import PropsBase
+from reflex.components.props import NoExtrasAllowedProps, PropsBase
 from reflex.event import (
 from reflex.event import (
     EventSpec,
     EventSpec,
     call_script,
     call_script,
@@ -72,7 +70,7 @@ def _toast_callback_signature(toast: Var) -> list[Var]:
     ]
     ]
 
 
 
 
-class ToastProps(PropsBase):
+class ToastProps(PropsBase, NoExtrasAllowedProps):
     """Props for the toast component."""
     """Props for the toast component."""
 
 
     # Toast's title, renders above the description.
     # Toast's title, renders above the description.
@@ -132,24 +130,6 @@ class ToastProps(PropsBase):
     # Function that gets called when the toast disappears automatically after it's timeout (duration` prop).
     # Function that gets called when the toast disappears automatically after it's timeout (duration` prop).
     on_auto_close: Optional[Any]
     on_auto_close: Optional[Any]
 
 
-    def __init__(self, **kwargs):
-        """Initialize the props.
-
-        Args:
-            kwargs: Kwargs to initialize the props.
-
-        Raises:
-            ValueError: If invalid props are passed on instantiation.
-        """
-        try:
-            super().__init__(**kwargs)
-        except ValidationError as e:
-            invalid_fields = ", ".join([error["loc"][0] for error in e.errors()])  # type: ignore
-            supported_props_str = ", ".join(f'"{field}"' for field in self.get_fields())
-            raise ValueError(
-                f"Invalid prop(s) {invalid_fields} for rx.toast. Supported props are {supported_props_str}"
-            ) from None
-
     def dict(self, *args, **kwargs) -> dict[str, Any]:
     def dict(self, *args, **kwargs) -> dict[str, Any]:
         """Convert the object to a dictionary.
         """Convert the object to a dictionary.
 
 
@@ -181,13 +161,6 @@ class ToastProps(PropsBase):
             )
             )
         return d
         return d
 
 
-    class Config:
-        """Pydantic config."""
-
-        arbitrary_types_allowed = True
-        use_enum_values = True
-        extra = "forbid"
-
 
 
 class Toaster(Component):
 class Toaster(Component):
     """A Toaster Component for displaying toast notifications."""
     """A Toaster Component for displaying toast notifications."""
@@ -281,7 +254,7 @@ class Toaster(Component):
         if message == "" and ("title" not in props or "description" not in props):
         if message == "" and ("title" not in props or "description" not in props):
             raise ValueError("Toast message or title or description must be provided.")
             raise ValueError("Toast message or title or description must be provided.")
         if props:
         if props:
-            args = LiteralVar.create(ToastProps(**props))
+            args = LiteralVar.create(ToastProps(component_name="rx.toast", **props))  # type: ignore
             toast = f"{toast_command}(`{message}`, {str(args)})"
             toast = f"{toast_command}(`{message}`, {str(args)})"
         else:
         else:
             toast = f"{toast_command}(`{message}`)"
             toast = f"{toast_command}(`{message}`)"

+ 2 - 7
reflex/components/sonner/toast.pyi

@@ -8,7 +8,7 @@ from typing import Any, ClassVar, Dict, Literal, Optional, Union, overload
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
-from reflex.components.props import PropsBase
+from reflex.components.props import NoExtrasAllowedProps, PropsBase
 from reflex.event import EventSpec, EventType
 from reflex.event import EventSpec, EventType
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils.serializers import serializer
 from reflex.utils.serializers import serializer
@@ -31,7 +31,7 @@ class ToastAction(Base):
 @serializer
 @serializer
 def serialize_action(action: ToastAction) -> dict: ...
 def serialize_action(action: ToastAction) -> dict: ...
 
 
-class ToastProps(PropsBase):
+class ToastProps(PropsBase, NoExtrasAllowedProps):
     title: Optional[Union[str, Var]]
     title: Optional[Union[str, Var]]
     description: Optional[Union[str, Var]]
     description: Optional[Union[str, Var]]
     close_button: Optional[bool]
     close_button: Optional[bool]
@@ -52,11 +52,6 @@ class ToastProps(PropsBase):
 
 
     def dict(self, *args, **kwargs) -> dict[str, Any]: ...
     def dict(self, *args, **kwargs) -> dict[str, Any]: ...
 
 
-    class Config:
-        arbitrary_types_allowed = True
-        use_enum_values = True
-        extra = "forbid"
-
 class Toaster(Component):
 class Toaster(Component):
     is_used: ClassVar[bool] = False
     is_used: ClassVar[bool] = False
 
 

+ 4 - 0
reflex/utils/exceptions.py

@@ -143,3 +143,7 @@ class EnvironmentVarValueError(ReflexError, ValueError):
 
 
 class DynamicComponentInvalidSignature(ReflexError, TypeError):
 class DynamicComponentInvalidSignature(ReflexError, TypeError):
     """Raised when a dynamic component has an invalid signature."""
     """Raised when a dynamic component has an invalid signature."""
+
+
+class InvalidPropValueError(ReflexError):
+    """Raised when a prop value is invalid."""

+ 63 - 0
tests/units/components/test_props.py

@@ -0,0 +1,63 @@
+import pytest
+
+from reflex.components.props import NoExtrasAllowedProps
+from reflex.utils.exceptions import InvalidPropValueError
+
+try:
+    from pydantic.v1 import ValidationError
+except ModuleNotFoundError:
+    from pydantic import ValidationError
+
+
+class PropA(NoExtrasAllowedProps):
+    """Base prop class."""
+
+    foo: str
+    bar: str
+
+
+class PropB(NoExtrasAllowedProps):
+    """Prop class with nested props."""
+
+    foobar: str
+    foobaz: PropA
+
+
+@pytest.mark.parametrize(
+    "props_class, kwargs, should_raise",
+    [
+        (PropA, {"foo": "value", "bar": "another_value"}, False),
+        (PropA, {"fooz": "value", "bar": "another_value"}, True),
+        (
+            PropB,
+            {
+                "foobaz": {"foo": "value", "bar": "another_value"},
+                "foobar": "foo_bar_value",
+            },
+            False,
+        ),
+        (
+            PropB,
+            {
+                "fooba": {"foo": "value", "bar": "another_value"},
+                "foobar": "foo_bar_value",
+            },
+            True,
+        ),
+        (
+            PropB,
+            {
+                "foobaz": {"foobar": "value", "bar": "another_value"},
+                "foobar": "foo_bar_value",
+            },
+            True,
+        ),
+    ],
+)
+def test_no_extras_allowed_props(props_class, kwargs, should_raise):
+    if should_raise:
+        with pytest.raises((ValidationError, InvalidPropValueError)):
+            props_class(**kwargs)
+    else:
+        props_instance = props_class(**kwargs)
+        assert isinstance(props_instance, props_class)