瀏覽代碼

[REF-2977] [REF-2982] Merge layout prop and automatic darkmode (#3442)

* [REF-2977] [REF-2982] Merge layout prop and automatic darkmode

* Expose `template` prop in rx.plotly
* Set default `template` to a color_mode_cond that responds to dark mode
* Merge user-provided `layout` and `template` props into the serialized
  plot data

* pyi_generator: avoid affecting the HTML Template component

* Raise warning when importing rx.plotly without plotly installed

* Remove PlotlyLib component and consolidate imports [fixup]
Masen Furer 11 月之前
父節點
當前提交
6e719d4c83
共有 4 個文件被更改,包括 114 次插入97 次删除
  1. 76 12
      reflex/components/plotly/plotly.py
  2. 15 82
      reflex/components/plotly/plotly.pyi
  3. 7 2
      reflex/utils/pyi_generator.py
  4. 16 1
      reflex/utils/serializers.py

+ 76 - 12
reflex/components/plotly/plotly.py

@@ -4,14 +4,20 @@ from __future__ import annotations
 from typing import Any, Dict, List
 
 from reflex.base import Base
-from reflex.components.component import NoSSRComponent
+from reflex.components.component import Component, NoSSRComponent
+from reflex.components.core.cond import color_mode_cond
 from reflex.event import EventHandler
+from reflex.utils import console
 from reflex.vars import Var
 
 try:
-    from plotly.graph_objects import Figure
+    from plotly.graph_objects import Figure, layout
+
+    Template = layout.Template
 except ImportError:
+    console.warn("Plotly is not installed. Please run `pip install plotly`.")
     Figure = Any  # type: ignore
+    Template = Any  # type: ignore
 
 
 def _event_data_signature(e0: Var) -> List[Any]:
@@ -84,17 +90,13 @@ def _null_signature() -> List[Any]:
     return []
 
 
-class PlotlyLib(NoSSRComponent):
-    """A component that wraps a plotly lib."""
+class Plotly(NoSSRComponent):
+    """Display a plotly graph."""
 
     library = "react-plotly.js@2.6.0"
 
     lib_dependencies: List[str] = ["plotly.js@2.22.0"]
 
-
-class Plotly(PlotlyLib):
-    """Display a plotly graph."""
-
     tag = "Plot"
 
     is_default = True
@@ -105,6 +107,9 @@ class Plotly(PlotlyLib):
     # The layout of the graph.
     layout: Var[Dict]
 
+    # The template for visual appearance of the graph.
+    template: Var[Template]
+
     # The config of the graph.
     config: Var[Dict]
 
@@ -171,6 +176,17 @@ class Plotly(PlotlyLib):
     # Fired when a hovered element is no longer hovered.
     on_unhover: EventHandler[_event_points_data_signature]
 
+    def add_imports(self) -> dict[str, str]:
+        """Add imports for the plotly component.
+
+        Returns:
+            The imports for the plotly component.
+        """
+        return {
+            # For merging plotly data/layout/templates.
+            "mergician@v2.0.2": "mergician"
+        }
+
     def add_custom_code(self) -> list[str]:
         """Add custom codes for processing the plotly points data.
 
@@ -210,14 +226,62 @@ const extractPoints = (points) => {
 """,
         ]
 
+    @classmethod
+    def create(cls, *children, **props) -> Component:
+        """Create the Plotly component.
+
+        Args:
+            *children: The children of the component.
+            **props: The properties of the component.
+
+        Returns:
+            The Plotly component.
+        """
+        from plotly.io import templates
+
+        responsive_template = color_mode_cond(
+            light=Var.create_safe(templates["plotly"]).to(dict),
+            dark=Var.create_safe(templates["plotly_dark"]).to(dict),
+        )
+        if isinstance(responsive_template, Var):
+            # Mark the conditional Var as a Template to avoid type mismatch
+            responsive_template = responsive_template.to(Template)
+        props.setdefault("template", responsive_template)
+        return super().create(*children, **props)
+
+    def _exclude_props(self) -> set[str]:
+        # These props are handled specially in the _render function
+        return {"data", "layout", "template"}
+
     def _render(self):
         tag = super()._render()
         figure = self.data.to(dict)
-        if self.layout is None:
-            tag.remove_props("data", "layout")
+        merge_dicts = []  # Data will be merged and spread from these dict Vars
+        if self.layout is not None:
+            # Why is this not a literal dict? Great question... it didn't work
+            # reliably because of how _var_name_unwrapped strips the outer curly
+            # brackets if any of the contained Vars depend on state.
+            layout_dict = Var.create_safe(
+                f"{{'layout': {self.layout.to(dict)._var_name_unwrapped}}}"
+            ).to(dict)
+            merge_dicts.append(layout_dict)
+        if self.template is not None:
+            template_dict = Var.create_safe(
+                {"layout": {"template": self.template.to(dict)}}
+            )
+            template_dict._var_data = None  # To avoid stripping outer curly brackets
+            merge_dicts.append(template_dict)
+        if merge_dicts:
             tag.special_props.add(
-                Var.create_safe(f"{{...{figure._var_name_unwrapped}}}")
+                # Merge all dictionaries and spread the result over props.
+                Var.create_safe(
+                    f"{{...mergician({figure._var_name_unwrapped},"
+                    f"{','.join(md._var_name_unwrapped for md in merge_dicts)})}}",
+                ),
             )
         else:
-            tag.add_props(data=figure["data"])
+            # Spread the figure dict over props, nothing to merge.
+            tag.special_props.add(
+                Var.create_safe(f"{{...{figure._var_name_unwrapped}}}")
+            )
         return tag

+ 15 - 82
reflex/components/plotly/plotly.pyi

@@ -9,97 +9,28 @@ from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from typing import Any, Dict, List
 from reflex.base import Base
-from reflex.components.component import NoSSRComponent
+from reflex.components.component import Component, NoSSRComponent
+from reflex.components.core.cond import color_mode_cond
 from reflex.event import EventHandler
+from reflex.utils import console
 from reflex.vars import Var
 
 try:
-    from plotly.graph_objects import Figure  # type: ignore
+    from plotly.graph_objects import Figure, layout  # type: ignore
+
+    Template = layout.Template
 except ImportError:
+    console.warn("Plotly is not installed. Please run `pip install plotly`.")
     Figure = Any  # type: ignore
+    Template = Any
 
 class _ButtonClickData(Base):
     menu: Any
     button: Any
     active: Any
 
-class PlotlyLib(NoSSRComponent):
-    @overload
-    @classmethod
-    def create(  # type: ignore
-        cls,
-        *children,
-        style: Optional[Style] = None,
-        key: Optional[Any] = None,
-        id: Optional[Any] = None,
-        class_name: Optional[Any] = None,
-        autofocus: Optional[bool] = None,
-        custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
-        on_blur: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_click: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_context_menu: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_double_click: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_focus: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mount: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_down: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_enter: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_leave: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_move: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_out: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_over: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_mouse_up: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_scroll: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        on_unmount: Optional[
-            Union[EventHandler, EventSpec, list, function, BaseVar]
-        ] = None,
-        **props
-    ) -> "PlotlyLib":
-        """Create the component.
-
-        Args:
-            *children: The children of the component.
-            style: The style of the component.
-            key: A unique key for the component.
-            id: The id for the component.
-            class_name: The class name for the component.
-            autofocus: Whether the component should take the focus once the page is loaded
-            custom_attrs: custom attribute
-            **props: The props of the component.
-
-        Returns:
-            The component.
-        """
-        ...
-
-class Plotly(PlotlyLib):
+class Plotly(NoSSRComponent):
+    def add_imports(self) -> dict[str, str]: ...
     def add_custom_code(self) -> list[str]: ...
     @overload
     @classmethod
@@ -108,6 +39,7 @@ class Plotly(PlotlyLib):
         *children,
         data: Optional[Union[Var[Figure], Figure]] = None,  # type: ignore
         layout: Optional[Union[Var[Dict], Dict]] = None,
+        template: Optional[Union[Var[Template], Template]] = None,  # type: ignore
         config: Optional[Union[Var[Dict], Dict]] = None,
         use_resize_handler: Optional[Union[Var[bool], bool]] = None,
         style: Optional[Style] = None,
@@ -217,12 +149,13 @@ class Plotly(PlotlyLib):
         ] = None,
         **props
     ) -> "Plotly":
-        """Create the component.
+        """Create the Plotly component.
 
         Args:
             *children: The children of the component.
             data: The figure to display. This can be a plotly figure or a plotly data json.
             layout: The layout of the graph.
+            template: The template for visual appearance of the graph.
             config: The config of the graph.
             use_resize_handler: If true, the graph will resize when the window is resized.
             style: The style of the component.
@@ -231,9 +164,9 @@ class Plotly(PlotlyLib):
             class_name: The class name for the component.
             autofocus: Whether the component should take the focus once the page is loaded
             custom_attrs: custom attribute
-            **props: The props of the component.
+            **props: The properties of the component.
 
         Returns:
-            The component.
+            The Plotly component.
         """
         ...

+ 7 - 2
reflex/utils/pyi_generator.py

@@ -32,7 +32,7 @@ logger = logging.getLogger("pyi_generator")
 PWD = Path(".").resolve()
 
 EXCLUDED_FILES = [
-    # "app.py",
+    "app.py",
     "component.py",
     "bare.py",
     "foreach.py",
@@ -856,7 +856,11 @@ class PyiGenerator:
                 mode=black.mode.Mode(is_pyi=True),
             ).splitlines():
                 # Bit of a hack here, since the AST cannot represent comments.
-                if "def create(" in formatted_line or "Figure" in formatted_line:
+                if (
+                    "def create(" in formatted_line
+                    or "Figure" in formatted_line
+                    or "Var[Template]" in formatted_line
+                ):
                     pyi_content.append(formatted_line + "  # type: ignore")
                 else:
                     pyi_content.append(formatted_line)
@@ -956,6 +960,7 @@ class PyiGenerator:
                 target_path.is_file()
                 and target_path.suffix == ".py"
                 and target_path.name not in EXCLUDED_FILES
+                and "reflex/components" in str(target_path)
             ):
                 file_targets.append(target_path)
                 continue

+ 16 - 1
reflex/utils/serializers.py

@@ -314,7 +314,7 @@ except ImportError:
     pass
 
 try:
-    from plotly.graph_objects import Figure
+    from plotly.graph_objects import Figure, layout
     from plotly.io import to_json
 
     @serializer
@@ -329,6 +329,21 @@ try:
         """
         return json.loads(str(to_json(figure)))
 
+    @serializer
+    def serialize_template(template: layout.Template) -> dict:
+        """Serialize a plotly template.
+
+        Args:
+            template: The template to serialize.
+
+        Returns:
+            The serialized template.
+        """
+        return {
+            "data": json.loads(str(to_json(template.data))),
+            "layout": json.loads(str(to_json(template.layout))),
+        }
+
 except ImportError:
     pass