Преглед изворни кода

Get `action`, `cancel`, `on_dismiss` and `on_auto_close` working for rx.toast (#3216)

* Get `action` and `cancel` working for rx.toast

Respect defaults set in ToastProvider toast_options when firing a toast with
it's own ToastProps set.

* Update reflex/components/sonner/toast.py

Co-authored-by: Thomas Brandého <thomas.brandeho@gmail.com>

* Move queueEvent formatting into rx.utils.format module

Implement on_auto_close and on_dismiss callbacks inside ToastProps

* Update rx.call_script to use new format.format_queue_events

Replace duplicate logic in rx.call_script for handling the callback function.

* Move PropsBase to reflex.components.props

This base class will be exposed via rx._x.PropsBase and can be shared by other
wrapped components that need to pass a JS object full of extra props.

---------

Co-authored-by: Thomas Brandého <thomas.brandeho@gmail.com>
Masen Furer пре 1 година
родитељ
комит
76c8b2dfbd

+ 30 - 0
reflex/components/props.py

@@ -0,0 +1,30 @@
+"""A class that holds props to be passed or applied to a component."""
+from __future__ import annotations
+
+from reflex.base import Base
+from reflex.utils import format
+from reflex.utils.serializers import serialize
+
+
+class PropsBase(Base):
+    """Base for a class containing props that can be serialized as a JS object."""
+
+    def json(self) -> str:
+        """Convert the object to a json-like string.
+
+        Vars will be unwrapped so they can represent actual JS var names and functions.
+
+        Keys will be converted to camelCase.
+
+        Returns:
+            The object as a Javascript Object literal.
+        """
+        return format.unwrap_vars(
+            self.__config__.json_dumps(
+                {
+                    format.to_camel_case(key): value
+                    for key, value in self.dict().items()
+                },
+                default=serialize,
+            )
+        )

+ 92 - 28
reflex/components/sonner/toast.py

@@ -2,16 +2,20 @@
 
 from __future__ import annotations
 
-from typing import Literal
+from typing import Any, Literal, Optional
 
 from reflex.base import Base
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.lucide.icon import Icon
-from reflex.event import EventSpec, call_script
+from reflex.components.props import PropsBase
+from reflex.event import (
+    EventSpec,
+    call_script,
+)
 from reflex.style import Style, color_mode
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
-from reflex.utils.serializers import serialize
+from reflex.utils.serializers import serialize, serializer
 from reflex.vars import Var, VarData
 
 LiteralPosition = Literal[
@@ -27,46 +31,68 @@ LiteralPosition = Literal[
 toast_ref = Var.create_safe("refs['__toast']")
 
 
-class PropsBase(Base):
-    """Base class for all props classes."""
+class ToastAction(Base):
+    """A toast action that render a button in the toast."""
 
-    def json(self) -> str:
-        """Convert the object to a json string.
+    label: str
+    on_click: Any
 
-        Returns:
-            The object as a json string.
-        """
-        from reflex.utils.serializers import serialize
 
-        return self.__config__.json_dumps(
-            {format.to_camel_case(key): value for key, value in self.dict().items()},
-            default=serialize,
+@serializer
+def serialize_action(action: ToastAction) -> dict:
+    """Serialize a toast action.
+
+    Args:
+        action: The toast action to serialize.
+
+    Returns:
+        The serialized toast action with on_click formatted to queue the given event.
+    """
+    return {
+        "label": action.label,
+        "onClick": format.format_queue_events(action.on_click),
+    }
+
+
+def _toast_callback_signature(toast: Var) -> list[Var]:
+    """The signature for the toast callback, stripping out unserializable keys.
+
+    Args:
+        toast: The toast variable.
+
+    Returns:
+        A function call stripping non-serializable members of the toast object.
+    """
+    return [
+        Var.create_safe(
+            f"(() => {{let {{action, cancel, onDismiss, onAutoClose, ...rest}} = {toast}; return rest}})()"
         )
+    ]
 
 
 class ToastProps(PropsBase):
     """Props for the toast component."""
 
     # Toast's description, renders underneath the title.
-    description: str = ""
+    description: Optional[str]
 
     # Whether to show the close button.
-    close_button: bool = False
+    close_button: Optional[bool]
 
     # Dark toast in light mode and vice versa.
-    invert: bool = False
+    invert: Optional[bool]
 
     # Control the sensitivity of the toast for screen readers
-    important: bool = False
+    important: Optional[bool]
 
     # Time in milliseconds that should elapse before automatically closing the toast.
-    duration: int = 4000
+    duration: Optional[int]
 
     # Position of the toast.
-    position: LiteralPosition = "bottom-right"
+    position: Optional[LiteralPosition]
 
     # If false, it'll prevent the user from dismissing the toast.
-    dismissible: bool = True
+    dismissible: Optional[bool]
 
     # TODO: fix serialization of icons for toast? (might not be possible yet)
     # Icon displayed in front of toast's text, aligned vertically.
@@ -74,25 +100,63 @@ class ToastProps(PropsBase):
 
     # TODO: fix implementation for action / cancel buttons
     # Renders a primary button, clicking it will close the toast.
-    # action: str = ""
+    action: Optional[ToastAction]
 
     # Renders a secondary button, clicking it will close the toast.
-    # cancel: str = ""
+    cancel: Optional[ToastAction]
 
     # Custom id for the toast.
-    id: str = ""
+    id: Optional[str]
 
     # Removes the default styling, which allows for easier customization.
-    unstyled: bool = False
+    unstyled: Optional[bool]
 
     # Custom style for the toast.
-    style: Style = Style()
+    style: Optional[Style]
 
+    # XXX: These still do not seem to work
     # Custom style for the toast primary button.
-    # action_button_styles: Style = Style()
+    action_button_styles: Optional[Style]
 
     # Custom style for the toast secondary button.
-    # cancel_button_styles: Style = Style()
+    cancel_button_styles: Optional[Style]
+
+    # The function gets called when either the close button is clicked, or the toast is swiped.
+    on_dismiss: Optional[Any]
+
+    # Function that gets called when the toast disappears automatically after it's timeout (duration` prop).
+    on_auto_close: Optional[Any]
+
+    def dict(self, *args, **kwargs) -> dict:
+        """Convert the object to a dictionary.
+
+        Args:
+            *args: The arguments to pass to the base class.
+            **kwargs: The keyword arguments to pass to the base
+
+        Returns:
+            The object as a dictionary with ToastAction fields intact.
+        """
+        kwargs.setdefault("exclude_none", True)
+        d = super().dict(*args, **kwargs)
+        # Keep these fields as ToastAction so they can be serialized specially
+        if "action" in d:
+            d["action"] = self.action
+            if isinstance(self.action, dict):
+                d["action"] = ToastAction(**self.action)
+        if "cancel" in d:
+            d["cancel"] = self.cancel
+            if isinstance(self.cancel, dict):
+                d["cancel"] = ToastAction(**self.cancel)
+        if "on_dismiss" in d:
+            d["on_dismiss"] = format.format_queue_events(
+                self.on_dismiss, _toast_callback_signature
+            )
+        if "on_auto_close" in d:
+            d["on_auto_close"] = format.format_queue_events(
+                self.on_auto_close, _toast_callback_signature
+            )
+        return d
 
 
 class Toaster(Component):

+ 27 - 14
reflex/components/sonner/toast.pyi

@@ -7,15 +7,16 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
-from typing import Literal
+from typing import Any, Literal, Optional
 from reflex.base import Base
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.lucide.icon import Icon
+from reflex.components.props import PropsBase
 from reflex.event import EventSpec, call_script
 from reflex.style import Style, color_mode
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
-from reflex.utils.serializers import serialize
+from reflex.utils.serializers import serialize, serializer
 from reflex.vars import Var, VarData
 
 LiteralPosition = Literal[
@@ -28,20 +29,32 @@ LiteralPosition = Literal[
 ]
 toast_ref = Var.create_safe("refs['__toast']")
 
-class PropsBase(Base):
-    def json(self) -> str: ...
+class ToastAction(Base):
+    label: str
+    on_click: Any
+
+@serializer
+def serialize_action(action: ToastAction) -> dict: ...
 
 class ToastProps(PropsBase):
-    description: str
-    close_button: bool
-    invert: bool
-    important: bool
-    duration: int
-    position: LiteralPosition
-    dismissible: bool
-    id: str
-    unstyled: bool
-    style: Style
+    description: Optional[str]
+    close_button: Optional[bool]
+    invert: Optional[bool]
+    important: Optional[bool]
+    duration: Optional[int]
+    position: Optional[LiteralPosition]
+    dismissible: Optional[bool]
+    action: Optional[ToastAction]
+    cancel: Optional[ToastAction]
+    id: Optional[str]
+    unstyled: Optional[bool]
+    style: Optional[Style]
+    action_button_styles: Optional[Style]
+    cancel_button_styles: Optional[Style]
+    on_dismiss: Optional[Any]
+    on_auto_close: Optional[Any]
+
+    def dict(self, *args, **kwargs) -> dict: ...
 
 class Toaster(Component):
     @staticmethod

+ 9 - 13
reflex/event.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 
 import inspect
 from base64 import b64encode
-from types import FunctionType
 from typing import (
     Any,
     Callable,
@@ -706,7 +705,11 @@ def _callback_arg_spec(eval_result):
 
 def call_script(
     javascript_code: str,
-    callback: EventHandler | Callable | None = None,
+    callback: EventSpec
+    | EventHandler
+    | Callable
+    | List[EventSpec | EventHandler | Callable]
+    | None = None,
 ) -> EventSpec:
     """Create an event handler that executes arbitrary javascript code.
 
@@ -716,21 +719,14 @@ def call_script(
 
     Returns:
         EventSpec: An event that will execute the client side javascript.
-
-    Raises:
-        ValueError: If the callback is not a valid event handler.
     """
     callback_kwargs = {}
     if callback is not None:
-        arg_name = parse_args_spec(_callback_arg_spec)[0]._var_name
-        if isinstance(callback, EventHandler):
-            event_spec = call_event_handler(callback, _callback_arg_spec)
-        elif isinstance(callback, FunctionType):
-            event_spec = call_event_fn(callback, _callback_arg_spec)[0]
-        else:
-            raise ValueError("Cannot use {callback!r} as a call_script callback.")
         callback_kwargs = {
-            "callback": f"({arg_name}) => queueEvents([{format.format_event(event_spec)}], {constants.CompileVars.SOCKET})"
+            "callback": format.format_queue_events(
+                callback,
+                args_spec=lambda result: [result],
+            )
         }
     return server_side(
         "_call_script",

+ 2 - 0
reflex/experimental/__init__.py

@@ -2,6 +2,7 @@
 
 from types import SimpleNamespace
 
+from reflex.components.props import PropsBase
 from reflex.components.radix.themes.components.progress import progress as progress
 from reflex.components.sonner.toast import toast as toast
 
@@ -18,6 +19,7 @@ _x = SimpleNamespace(
     hooks=hooks,
     layout=layout,
     progress=progress,
+    PropsBase=PropsBase,
     run_in_thread=run_in_thread,
     toast=toast,
 )

+ 73 - 2
reflex/utils/format.py

@@ -6,7 +6,7 @@ import inspect
 import json
 import os
 import re
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
 
 from reflex import constants
 from reflex.utils import exceptions, serializers, types
@@ -15,7 +15,7 @@ from reflex.vars import BaseVar, Var
 
 if TYPE_CHECKING:
     from reflex.components.component import ComponentStyle
-    from reflex.event import EventChain, EventHandler, EventSpec
+    from reflex.event import ArgsSpec, EventChain, EventHandler, EventSpec
 
 WRAP_MAP = {
     "{": "}",
@@ -590,6 +590,77 @@ def format_event_chain(
     )
 
 
+def format_queue_events(
+    events: EventSpec
+    | EventHandler
+    | Callable
+    | List[EventSpec | EventHandler | Callable]
+    | None = None,
+    args_spec: Optional[ArgsSpec] = None,
+) -> Var[EventChain]:
+    """Format a list of event handler / event spec as a javascript callback.
+
+    The resulting code can be passed to interfaces that expect a callback
+    function and when triggered it will directly call queueEvents.
+
+    It is intended to be executed in the rx.call_script context, where some
+    existing API needs a callback to trigger a backend event handler.
+
+    Args:
+        events: The events to queue.
+        args_spec: The argument spec for the callback.
+
+    Returns:
+        The compiled javascript callback to queue the given events on the frontend.
+    """
+    from reflex.event import (
+        EventChain,
+        EventHandler,
+        EventSpec,
+        call_event_fn,
+        call_event_handler,
+    )
+
+    if not events:
+        return Var.create_safe(
+            "() => null", _var_is_string=False, _var_is_local=False
+        ).to(EventChain)
+
+    # If no spec is provided, the function will take no arguments.
+    def _default_args_spec():
+        return []
+
+    # Construct the arguments that the function accepts.
+    sig = inspect.signature(args_spec or _default_args_spec)  # type: ignore
+    if sig.parameters:
+        arg_def = ",".join(f"_{p}" for p in sig.parameters)
+        arg_def = f"({arg_def})"
+    else:
+        arg_def = "()"
+
+    payloads = []
+    if not isinstance(events, list):
+        events = [events]
+
+    # Process each event/spec/lambda (similar to Component._create_event_chain).
+    for spec in events:
+        specs: list[EventSpec] = []
+        if isinstance(spec, (EventHandler, EventSpec)):
+            specs = [call_event_handler(spec, args_spec or _default_args_spec)]
+        elif isinstance(spec, type(lambda: None)):
+            specs = call_event_fn(spec, args_spec or _default_args_spec)
+        payloads.extend(format_event(s) for s in specs)
+
+    # Return the final code snippet, expecting queueEvents, processEvent, and socket to be in scope.
+    # Typically this snippet will _only_ run from within an rx.call_script eval context.
+    return Var.create_safe(
+        f"{arg_def} => {{queueEvents([{','.join(payloads)}], {constants.CompileVars.SOCKET}); "
+        f"processEvent({constants.CompileVars.SOCKET})}}",
+        _var_is_string=False,
+        _var_is_local=False,
+    ).to(EventChain)
+
+
 def format_query_params(router_data: dict[str, Any]) -> dict[str, str]:
     """Convert back query params name to python-friendly case.