Преглед на файлове

[REF-2977] [REF-2982] Merge layout prop and automatic darkmode (#3442)

* [REF-2977] [REF-2982] Merge layout prop and automatic darkmode

* Expose `template` prop in rx.plotly
* Set default `template` to a color_mode_cond that responds to dark mode
* Merge user-provided `layout` and `template` props into the serialized
  plot data

* pyi_generator: avoid affecting the HTML Template component

* Raise warning when importing rx.plotly without plotly installed

* Remove PlotlyLib component and consolidate imports [fixup]
Masen Furer преди 11 месеца
родител
ревизия
6e719d4c83
променени са 4 файла, в които са добавени 114 реда и са изтрити 97 реда
  1. 76 12
      reflex/components/plotly/plotly.py
  2. 15 82
      reflex/components/plotly/plotly.pyi
  3. 7 2
      reflex/utils/pyi_generator.py
  4. 16 1
      reflex/utils/serializers.py

+ 76 - 12
reflex/components/plotly/plotly.py

@@ -4,14 +4,20 @@ from __future__ import annotations
 from typing import Any, Dict, List
 from typing import Any, Dict, List
 
 
 from reflex.base import Base
 from reflex.base import Base
-from reflex.components.component import NoSSRComponent
+from reflex.components.component import Component, NoSSRComponent
+from reflex.components.core.cond import color_mode_cond
 from reflex.event import EventHandler
 from reflex.event import EventHandler
+from reflex.utils import console
 from reflex.vars import Var
 from reflex.vars import Var
 
 
 try:
 try:
-    from plotly.graph_objects import Figure
+    from plotly.graph_objects import Figure, layout
+
+    Template = layout.Template
 except ImportError:
 except ImportError:
+    console.warn("Plotly is not installed. Please run `pip install plotly`.")
     Figure = Any  # type: ignore
     Figure = Any  # type: ignore
+    Template = Any  # type: ignore
 
 
 
 
 def _event_data_signature(e0: Var) -> List[Any]:
 def _event_data_signature(e0: Var) -> List[Any]:
@@ -84,17 +90,13 @@ def _null_signature() -> List[Any]:
     return []
     return []
 
 
 
 
-class PlotlyLib(NoSSRComponent):
-    """A component that wraps a plotly lib."""
+class Plotly(NoSSRComponent):
+    """Display a plotly graph."""
 
 
     library = "react-plotly.js@2.6.0"
     library = "react-plotly.js@2.6.0"
 
 
     lib_dependencies: List[str] = ["plotly.js@2.22.0"]
     lib_dependencies: List[str] = ["plotly.js@2.22.0"]
 
 
-
-class Plotly(PlotlyLib):
-    """Display a plotly graph."""
-
     tag = "Plot"
     tag = "Plot"
 
 
     is_default = True
     is_default = True
@@ -105,6 +107,9 @@ class Plotly(PlotlyLib):
     # The layout of the graph.
     # The layout of the graph.
     layout: Var[Dict]
     layout: Var[Dict]
 
 
+    # The template for visual appearance of the graph.
+    template: Var[Template]
+
     # The config of the graph.
     # The config of the graph.
     config: Var[Dict]
     config: Var[Dict]
 
 
@@ -171,6 +176,17 @@ class Plotly(PlotlyLib):
     # Fired when a hovered element is no longer hovered.
     # Fired when a hovered element is no longer hovered.
     on_unhover: EventHandler[_event_points_data_signature]
     on_unhover: EventHandler[_event_points_data_signature]
 
 
+    def add_imports(self) -> dict[str, str]:
+        """Add imports for the plotly component.
+
+        Returns:
+            The imports for the plotly component.
+        """
+        return {
+            # For merging plotly data/layout/templates.
+            "mergician@v2.0.2": "mergician"
+        }
+
     def add_custom_code(self) -> list[str]:
     def add_custom_code(self) -> list[str]:
         """Add custom codes for processing the plotly points data.
         """Add custom codes for processing the plotly points data.
 
 
@@ -210,14 +226,62 @@ const extractPoints = (points) => {
 """,
 """,
         ]
         ]
 
 
+    @classmethod
+    def create(cls, *children, **props) -> Component:
+        """Create the Plotly component.
+
+        Args:
+            *children: The children of the component.
+            **props: The properties of the component.
+
+        Returns:
+            The Plotly component.
+        """
+        from plotly.io import templates
+
+        responsive_template = color_mode_cond(
+            light=Var.create_safe(templates["plotly"]).to(dict),
+            dark=Var.create_safe(templates["plotly_dark"]).to(dict),
+        )
+        if isinstance(responsive_template, Var):
+            # Mark the conditional Var as a Template to avoid type mismatch
+            responsive_template = responsive_template.to(Template)
+        props.setdefault("template", responsive_template)
+        return super().create(*children, **props)
+
+    def _exclude_props(self) -> set[str]:
+        # These props are handled specially in the _render function
+        return {"data", "layout", "template"}
+
     def _render(self):
     def _render(self):
         tag = super()._render()
         tag = super()._render()
         figure = self.data.to(dict)
         figure = self.data.to(dict)
-        if self.layout is None:
-            tag.remove_props("data", "layout")
+        merge_dicts = []  # Data will be merged and spread from these dict Vars
+        if self.layout is not None:
+            # Why is this not a literal dict? Great question... it didn't work
+            # reliably because of how _var_name_unwrapped strips the outer curly
+            # brackets if any of the contained Vars depend on state.
+            layout_dict = Var.create_safe(
+                f"{{'layout': {self.layout.to(dict)._var_name_unwrapped}}}"
+            ).to(dict)
+            merge_dicts.append(layout_dict)
+        if self.template is not None:
+            template_dict = Var.create_safe(
+                {"layout": {"template": self.template.to(dict)}}
+            )
+            template_dict._var_data = None  # To avoid stripping outer curly brackets
+            merge_dicts.append(template_dict)
+        if merge_dicts:
             tag.special_props.add(
             tag.special_props.add(
-                Var.create_safe(f"{{...{figure._var_name_unwrapped}}}")
+                # Merge all dictionaries and spread the result over props.
+                Var.create_safe(
+                    f"{{...mergician({figure._var_name_unwrapped},"
+                    f"{','.join(md._var_name_unwrapped for md in merge_dicts)})}}",
+                ),
             )
             )
         else:
         else:
-            tag.add_props(data=figure["data"])
+            # Spread the figure dict over props, nothing to merge.
+            tag.special_props.add(
+                Var.create_safe(f"{{...{figure._var_name_unwrapped}}}")
+            )
         return tag
         return tag

+ 15 - 82
reflex/components/plotly/plotly.pyi

@@ -9,97 +9,28 @@ from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex.style import Style
 from typing import Any, Dict, List
 from typing import Any, Dict, List
 from reflex.base import Base
 from reflex.base import Base
-from reflex.components.component import NoSSRComponent
+from reflex.components.component import Component, NoSSRComponent
+from reflex.components.core.cond import color_mode_cond
 from reflex.event import EventHandler
 from reflex.event import EventHandler
+from reflex.utils import console
 from reflex.vars import Var
 from reflex.vars import Var
 
 
 try:
 try:
-    from plotly.graph_objects import Figure  # type: ignore
+    from plotly.graph_objects import Figure, layout  # type: ignore
+
+    Template = layout.Template
 except ImportError:
 except ImportError:
+    console.warn("Plotly is not installed. Please run `pip install plotly`.")
     Figure = Any  # type: ignore
     Figure = Any  # type: ignore
+    Template = Any
 
 
 class _ButtonClickData(Base):
 class _ButtonClickData(Base):
     menu: Any
     menu: Any
     button: Any
     button: Any
     active: Any
     active: Any
 
 
-class PlotlyLib(NoSSRComponent):
-    @overload
-    @classmethod
-    def create(  # type: ignore
-        cls,
-        *children,
-        style: Optional[Style] = None,
-        key: Optional[Any] = None,
-        id: Optional[Any] = None,
-        class_name: Optional[Any] = None,
-        autofocus: Optional[bool] = None,
-        custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
-        on_blur: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_click: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_context_menu: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_double_click: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_focus: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mount: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_down: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_enter: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_leave: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_move: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_out: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_over: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_up: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_scroll: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_unmount: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        **props
-    ) -> "PlotlyLib":
-        """Create the component.
-
-        Args:
-            *children: The children of the component.
-            style: The style of the component.
-            key: A unique key for the component.
-            id: The id for the component.
-            class_name: The class name for the component.
-            autofocus: Whether the component should take the focus once the page is loaded
-            custom_attrs: custom attribute
-            **props: The props of the component.
-
-        Returns:
-            The component.
-        """
-        ...
-
-class Plotly(PlotlyLib):
+class Plotly(NoSSRComponent):
+    def add_imports(self) -> dict[str, str]: ...
     def add_custom_code(self) -> list[str]: ...
     def add_custom_code(self) -> list[str]: ...
     @overload
     @overload
     @classmethod
     @classmethod
@@ -108,6 +39,7 @@ class Plotly(PlotlyLib):
         *children,
         *children,
         data: Optional[Union[Var[Figure], Figure]] = None,  # type: ignore
         data: Optional[Union[Var[Figure], Figure]] = None,  # type: ignore
         layout: Optional[Union[Var[Dict], Dict]] = None,
         layout: Optional[Union[Var[Dict], Dict]] = None,
+        template: Optional[Union[Var[Template], Template]] = None,  # type: ignore
         config: Optional[Union[Var[Dict], Dict]] = None,
         config: Optional[Union[Var[Dict], Dict]] = None,
         use_resize_handler: Optional[Union[Var[bool], bool]] = None,
         use_resize_handler: Optional[Union[Var[bool], bool]] = None,
         style: Optional[Style] = None,
         style: Optional[Style] = None,
@@ -217,12 +149,13 @@ class Plotly(PlotlyLib):
         ] = None,
         ] = None,
         **props
         **props
     ) -> "Plotly":
     ) -> "Plotly":
-        """Create the component.
+        """Create the Plotly component.
 
 
         Args:
         Args:
             *children: The children of the component.
             *children: The children of the component.
             data: The figure to display. This can be a plotly figure or a plotly data json.
             data: The figure to display. This can be a plotly figure or a plotly data json.
             layout: The layout of the graph.
             layout: The layout of the graph.
+            template: The template for visual appearance of the graph.
             config: The config of the graph.
             config: The config of the graph.
             use_resize_handler: If true, the graph will resize when the window is resized.
             use_resize_handler: If true, the graph will resize when the window is resized.
             style: The style of the component.
             style: The style of the component.
@@ -231,9 +164,9 @@ class Plotly(PlotlyLib):
             class_name: The class name for the component.
             class_name: The class name for the component.
             autofocus: Whether the component should take the focus once the page is loaded
             autofocus: Whether the component should take the focus once the page is loaded
             custom_attrs: custom attribute
             custom_attrs: custom attribute
-            **props: The props of the component.
+            **props: The properties of the component.
 
 
         Returns:
         Returns:
-            The component.
+            The Plotly component.
         """
         """
         ...
         ...

+ 7 - 2
reflex/utils/pyi_generator.py

@@ -32,7 +32,7 @@ logger = logging.getLogger("pyi_generator")
 PWD = Path(".").resolve()
 PWD = Path(".").resolve()
 
 
 EXCLUDED_FILES = [
 EXCLUDED_FILES = [
-    # "app.py",
+    "app.py",
     "component.py",
     "component.py",
     "bare.py",
     "bare.py",
     "foreach.py",
     "foreach.py",
@@ -856,7 +856,11 @@ class PyiGenerator:
                 mode=black.mode.Mode(is_pyi=True),
                 mode=black.mode.Mode(is_pyi=True),
             ).splitlines():
             ).splitlines():
                 # Bit of a hack here, since the AST cannot represent comments.
                 # Bit of a hack here, since the AST cannot represent comments.
-                if "def create(" in formatted_line or "Figure" in formatted_line:
+                if (
+                    "def create(" in formatted_line
+                    or "Figure" in formatted_line
+                    or "Var[Template]" in formatted_line
+                ):
                     pyi_content.append(formatted_line + "  # type: ignore")
                     pyi_content.append(formatted_line + "  # type: ignore")
                 else:
                 else:
                     pyi_content.append(formatted_line)
                     pyi_content.append(formatted_line)
@@ -956,6 +960,7 @@ class PyiGenerator:
                 target_path.is_file()
                 target_path.is_file()
                 and target_path.suffix == ".py"
                 and target_path.suffix == ".py"
                 and target_path.name not in EXCLUDED_FILES
                 and target_path.name not in EXCLUDED_FILES
+                and "reflex/components" in str(target_path)
             ):
             ):
                 file_targets.append(target_path)
                 file_targets.append(target_path)
                 continue
                 continue

+ 16 - 1
reflex/utils/serializers.py

@@ -314,7 +314,7 @@ except ImportError:
     pass
     pass
 
 
 try:
 try:
-    from plotly.graph_objects import Figure
+    from plotly.graph_objects import Figure, layout
     from plotly.io import to_json
     from plotly.io import to_json
 
 
     @serializer
     @serializer
@@ -329,6 +329,21 @@ try:
         """
         """
         return json.loads(str(to_json(figure)))
         return json.loads(str(to_json(figure)))
 
 
+    @serializer
+    def serialize_template(template: layout.Template) -> dict:
+        """Serialize a plotly template.
+
+        Args:
+            template: The template to serialize.
+
+        Returns:
+            The serialized template.
+        """
+        return {
+            "data": json.loads(str(to_json(template.data))),
+            "layout": json.loads(str(to_json(template.layout))),
+        }
+
 except ImportError:
 except ImportError:
     pass
     pass