Browse Source

Get `action`, `cancel`, `on_dismiss` and `on_auto_close` working for rx.toast (#3216)

* Get `action` and `cancel` working for rx.toast

Respect defaults set in ToastProvider toast_options when firing a toast with
it's own ToastProps set.

* Update reflex/components/sonner/toast.py

Co-authored-by: Thomas Brandého <thomas.brandeho@gmail.com>

* Move queueEvent formatting into rx.utils.format module

Implement on_auto_close and on_dismiss callbacks inside ToastProps

* Update rx.call_script to use new format.format_queue_events

Replace duplicate logic in rx.call_script for handling the callback function.

* Move PropsBase to reflex.components.props

This base class will be exposed via rx._x.PropsBase and can be shared by other
wrapped components that need to pass a JS object full of extra props.

---------

Co-authored-by: Thomas Brandého <thomas.brandeho@gmail.com>
Masen Furer 1 year ago
parent
commit
76c8b2dfbd

+ 30 - 0
reflex/components/props.py

@@ -0,0 +1,30 @@
+"""A class that holds props to be passed or applied to a component."""
+from __future__ import annotations
+
+from reflex.base import Base
+from reflex.utils import format
+from reflex.utils.serializers import serialize
+
+
+class PropsBase(Base):
+    """Base for a class containing props that can be serialized as a JS object."""
+
+    def json(self) -> str:
+        """Convert the object to a json-like string.
+
+        Vars will be unwrapped so they can represent actual JS var names and functions.
+
+        Keys will be converted to camelCase.
+
+        Returns:
+            The object as a Javascript Object literal.
+        """
+        return format.unwrap_vars(
+            self.__config__.json_dumps(
+                {
+                    format.to_camel_case(key): value
+                    for key, value in self.dict().items()
+                },
+                default=serialize,
+            )
+        )

+ 92 - 28
reflex/components/sonner/toast.py

@@ -2,16 +2,20 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Literal
+from typing import Any, Literal, Optional
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
-from reflex.event import EventSpec, call_script
+from reflex.components.props import PropsBase
+from reflex.event import (
+    EventSpec,
+    call_script,
+)
 from reflex.style import Style, color_mode
 from reflex.style import Style, color_mode
 from reflex.utils import format
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
-from reflex.utils.serializers import serialize
+from reflex.utils.serializers import serialize, serializer
 from reflex.vars import Var, VarData
 from reflex.vars import Var, VarData
 
 
 LiteralPosition = Literal[
 LiteralPosition = Literal[
@@ -27,46 +31,68 @@ LiteralPosition = Literal[
 toast_ref = Var.create_safe("refs['__toast']")
 toast_ref = Var.create_safe("refs['__toast']")
 
 
 
 
-class PropsBase(Base):
-    """Base class for all props classes."""
+class ToastAction(Base):
+    """A toast action that render a button in the toast."""
 
 
-    def json(self) -> str:
-        """Convert the object to a json string.
+    label: str
+    on_click: Any
 
 
-        Returns:
-            The object as a json string.
-        """
-        from reflex.utils.serializers import serialize
 
 
-        return self.__config__.json_dumps(
-            {format.to_camel_case(key): value for key, value in self.dict().items()},
-            default=serialize,
+@serializer
+def serialize_action(action: ToastAction) -> dict:
+    """Serialize a toast action.
+
+    Args:
+        action: The toast action to serialize.
+
+    Returns:
+        The serialized toast action with on_click formatted to queue the given event.
+    """
+    return {
+        "label": action.label,
+        "onClick": format.format_queue_events(action.on_click),
+    }
+
+
+def _toast_callback_signature(toast: Var) -> list[Var]:
+    """The signature for the toast callback, stripping out unserializable keys.
+
+    Args:
+        toast: The toast variable.
+
+    Returns:
+        A function call stripping non-serializable members of the toast object.
+    """
+    return [
+        Var.create_safe(
+            f"(() => {{let {{action, cancel, onDismiss, onAutoClose, ...rest}} = {toast}; return rest}})()"
         )
         )
+    ]
 
 
 
 
 class ToastProps(PropsBase):
 class ToastProps(PropsBase):
     """Props for the toast component."""
     """Props for the toast component."""
 
 
     # Toast's description, renders underneath the title.
     # Toast's description, renders underneath the title.
-    description: str = ""
+    description: Optional[str]
 
 
     # Whether to show the close button.
     # Whether to show the close button.
-    close_button: bool = False
+    close_button: Optional[bool]
 
 
     # Dark toast in light mode and vice versa.
     # Dark toast in light mode and vice versa.
-    invert: bool = False
+    invert: Optional[bool]
 
 
     # Control the sensitivity of the toast for screen readers
     # Control the sensitivity of the toast for screen readers
-    important: bool = False
+    important: Optional[bool]
 
 
     # Time in milliseconds that should elapse before automatically closing the toast.
     # Time in milliseconds that should elapse before automatically closing the toast.
-    duration: int = 4000
+    duration: Optional[int]
 
 
     # Position of the toast.
     # Position of the toast.
-    position: LiteralPosition = "bottom-right"
+    position: Optional[LiteralPosition]
 
 
     # If false, it'll prevent the user from dismissing the toast.
     # If false, it'll prevent the user from dismissing the toast.
-    dismissible: bool = True
+    dismissible: Optional[bool]
 
 
     # TODO: fix serialization of icons for toast? (might not be possible yet)
     # TODO: fix serialization of icons for toast? (might not be possible yet)
     # Icon displayed in front of toast's text, aligned vertically.
     # Icon displayed in front of toast's text, aligned vertically.
@@ -74,25 +100,63 @@ class ToastProps(PropsBase):
 
 
     # TODO: fix implementation for action / cancel buttons
     # TODO: fix implementation for action / cancel buttons
     # Renders a primary button, clicking it will close the toast.
     # Renders a primary button, clicking it will close the toast.
-    # action: str = ""
+    action: Optional[ToastAction]
 
 
     # Renders a secondary button, clicking it will close the toast.
     # Renders a secondary button, clicking it will close the toast.
-    # cancel: str = ""
+    cancel: Optional[ToastAction]
 
 
     # Custom id for the toast.
     # Custom id for the toast.
-    id: str = ""
+    id: Optional[str]
 
 
     # Removes the default styling, which allows for easier customization.
     # Removes the default styling, which allows for easier customization.
-    unstyled: bool = False
+    unstyled: Optional[bool]
 
 
     # Custom style for the toast.
     # Custom style for the toast.
-    style: Style = Style()
+    style: Optional[Style]
 
 
+    # XXX: These still do not seem to work
     # Custom style for the toast primary button.
     # Custom style for the toast primary button.
-    # action_button_styles: Style = Style()
+    action_button_styles: Optional[Style]
 
 
     # Custom style for the toast secondary button.
     # Custom style for the toast secondary button.
-    # cancel_button_styles: Style = Style()
+    cancel_button_styles: Optional[Style]
+
+    # The function gets called when either the close button is clicked, or the toast is swiped.
+    on_dismiss: Optional[Any]
+
+    # Function that gets called when the toast disappears automatically after it's timeout (duration` prop).
+    on_auto_close: Optional[Any]
+
+    def dict(self, *args, **kwargs) -> dict:
+        """Convert the object to a dictionary.
+
+        Args:
+            *args: The arguments to pass to the base class.
+            **kwargs: The keyword arguments to pass to the base
+
+        Returns:
+            The object as a dictionary with ToastAction fields intact.
+        """
+        kwargs.setdefault("exclude_none", True)
+        d = super().dict(*args, **kwargs)
+        # Keep these fields as ToastAction so they can be serialized specially
+        if "action" in d:
+            d["action"] = self.action
+            if isinstance(self.action, dict):
+                d["action"] = ToastAction(**self.action)
+        if "cancel" in d:
+            d["cancel"] = self.cancel
+            if isinstance(self.cancel, dict):
+                d["cancel"] = ToastAction(**self.cancel)
+        if "on_dismiss" in d:
+            d["on_dismiss"] = format.format_queue_events(
+                self.on_dismiss, _toast_callback_signature
+            )
+        if "on_auto_close" in d:
+            d["on_auto_close"] = format.format_queue_events(
+                self.on_auto_close, _toast_callback_signature
+            )
+        return d
 
 
 
 
 class Toaster(Component):
 class Toaster(Component):

+ 27 - 14
reflex/components/sonner/toast.pyi

@@ -7,15 +7,16 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex.style import Style
-from typing import Literal
+from typing import Any, Literal, Optional
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
+from reflex.components.props import PropsBase
 from reflex.event import EventSpec, call_script
 from reflex.event import EventSpec, call_script
 from reflex.style import Style, color_mode
 from reflex.style import Style, color_mode
 from reflex.utils import format
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
-from reflex.utils.serializers import serialize
+from reflex.utils.serializers import serialize, serializer
 from reflex.vars import Var, VarData
 from reflex.vars import Var, VarData
 
 
 LiteralPosition = Literal[
 LiteralPosition = Literal[
@@ -28,20 +29,32 @@ LiteralPosition = Literal[
 ]
 ]
 toast_ref = Var.create_safe("refs['__toast']")
 toast_ref = Var.create_safe("refs['__toast']")
 
 
-class PropsBase(Base):
-    def json(self) -> str: ...
+class ToastAction(Base):
+    label: str
+    on_click: Any
+
+@serializer
+def serialize_action(action: ToastAction) -> dict: ...
 
 
 class ToastProps(PropsBase):
 class ToastProps(PropsBase):
-    description: str
-    close_button: bool
-    invert: bool
-    important: bool
-    duration: int
-    position: LiteralPosition
-    dismissible: bool
-    id: str
-    unstyled: bool
-    style: Style
+    description: Optional[str]
+    close_button: Optional[bool]
+    invert: Optional[bool]
+    important: Optional[bool]
+    duration: Optional[int]
+    position: Optional[LiteralPosition]
+    dismissible: Optional[bool]
+    action: Optional[ToastAction]
+    cancel: Optional[ToastAction]
+    id: Optional[str]
+    unstyled: Optional[bool]
+    style: Optional[Style]
+    action_button_styles: Optional[Style]
+    cancel_button_styles: Optional[Style]
+    on_dismiss: Optional[Any]
+    on_auto_close: Optional[Any]
+
+    def dict(self, *args, **kwargs) -> dict: ...
 
 
 class Toaster(Component):
 class Toaster(Component):
     @staticmethod
     @staticmethod

+ 9 - 13
reflex/event.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 
 
 import inspect
 import inspect
 from base64 import b64encode
 from base64 import b64encode
-from types import FunctionType
 from typing import (
 from typing import (
     Any,
     Any,
     Callable,
     Callable,
@@ -706,7 +705,11 @@ def _callback_arg_spec(eval_result):
 
 
 def call_script(
 def call_script(
     javascript_code: str,
     javascript_code: str,
-    callback: EventHandler | Callable | None = None,
+    callback: EventSpec
+    | EventHandler
+    | Callable
+    | List[EventSpec | EventHandler | Callable]
+    | None = None,
 ) -> EventSpec:
 ) -> EventSpec:
     """Create an event handler that executes arbitrary javascript code.
     """Create an event handler that executes arbitrary javascript code.
 
 
@@ -716,21 +719,14 @@ def call_script(
 
 
     Returns:
     Returns:
         EventSpec: An event that will execute the client side javascript.
         EventSpec: An event that will execute the client side javascript.
-
-    Raises:
-        ValueError: If the callback is not a valid event handler.
     """
     """
     callback_kwargs = {}
     callback_kwargs = {}
     if callback is not None:
     if callback is not None:
-        arg_name = parse_args_spec(_callback_arg_spec)[0]._var_name
-        if isinstance(callback, EventHandler):
-            event_spec = call_event_handler(callback, _callback_arg_spec)
-        elif isinstance(callback, FunctionType):
-            event_spec = call_event_fn(callback, _callback_arg_spec)[0]
-        else:
-            raise ValueError("Cannot use {callback!r} as a call_script callback.")
         callback_kwargs = {
         callback_kwargs = {
-            "callback": f"({arg_name}) => queueEvents([{format.format_event(event_spec)}], {constants.CompileVars.SOCKET})"
+            "callback": format.format_queue_events(
+                callback,
+                args_spec=lambda result: [result],
+            )
         }
         }
     return server_side(
     return server_side(
         "_call_script",
         "_call_script",

+ 2 - 0
reflex/experimental/__init__.py

@@ -2,6 +2,7 @@
 
 
 from types import SimpleNamespace
 from types import SimpleNamespace
 
 
+from reflex.components.props import PropsBase
 from reflex.components.radix.themes.components.progress import progress as progress
 from reflex.components.radix.themes.components.progress import progress as progress
 from reflex.components.sonner.toast import toast as toast
 from reflex.components.sonner.toast import toast as toast
 
 
@@ -18,6 +19,7 @@ _x = SimpleNamespace(
     hooks=hooks,
     hooks=hooks,
     layout=layout,
     layout=layout,
     progress=progress,
     progress=progress,
+    PropsBase=PropsBase,
     run_in_thread=run_in_thread,
     run_in_thread=run_in_thread,
     toast=toast,
     toast=toast,
 )
 )

+ 73 - 2
reflex/utils/format.py

@@ -6,7 +6,7 @@ import inspect
 import json
 import json
 import os
 import os
 import re
 import re
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
 
 
 from reflex import constants
 from reflex import constants
 from reflex.utils import exceptions, serializers, types
 from reflex.utils import exceptions, serializers, types
@@ -15,7 +15,7 @@ from reflex.vars import BaseVar, Var
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.components.component import ComponentStyle
     from reflex.components.component import ComponentStyle
-    from reflex.event import EventChain, EventHandler, EventSpec
+    from reflex.event import ArgsSpec, EventChain, EventHandler, EventSpec
 
 
 WRAP_MAP = {
 WRAP_MAP = {
     "{": "}",
     "{": "}",
@@ -590,6 +590,77 @@ def format_event_chain(
     )
     )
 
 
 
 
+def format_queue_events(
+    events: EventSpec
+    | EventHandler
+    | Callable
+    | List[EventSpec | EventHandler | Callable]
+    | None = None,
+    args_spec: Optional[ArgsSpec] = None,
+) -> Var[EventChain]:
+    """Format a list of event handler / event spec as a javascript callback.
+
+    The resulting code can be passed to interfaces that expect a callback
+    function and when triggered it will directly call queueEvents.
+
+    It is intended to be executed in the rx.call_script context, where some
+    existing API needs a callback to trigger a backend event handler.
+
+    Args:
+        events: The events to queue.
+        args_spec: The argument spec for the callback.
+
+    Returns:
+        The compiled javascript callback to queue the given events on the frontend.
+    """
+    from reflex.event import (
+        EventChain,
+        EventHandler,
+        EventSpec,
+        call_event_fn,
+        call_event_handler,
+    )
+
+    if not events:
+        return Var.create_safe(
+            "() => null", _var_is_string=False, _var_is_local=False
+        ).to(EventChain)
+
+    # If no spec is provided, the function will take no arguments.
+    def _default_args_spec():
+        return []
+
+    # Construct the arguments that the function accepts.
+    sig = inspect.signature(args_spec or _default_args_spec)  # type: ignore
+    if sig.parameters:
+        arg_def = ",".join(f"_{p}" for p in sig.parameters)
+        arg_def = f"({arg_def})"
+    else:
+        arg_def = "()"
+
+    payloads = []
+    if not isinstance(events, list):
+        events = [events]
+
+    # Process each event/spec/lambda (similar to Component._create_event_chain).
+    for spec in events:
+        specs: list[EventSpec] = []
+        if isinstance(spec, (EventHandler, EventSpec)):
+            specs = [call_event_handler(spec, args_spec or _default_args_spec)]
+        elif isinstance(spec, type(lambda: None)):
+            specs = call_event_fn(spec, args_spec or _default_args_spec)
+        payloads.extend(format_event(s) for s in specs)
+
+    # Return the final code snippet, expecting queueEvents, processEvent, and socket to be in scope.
+    # Typically this snippet will _only_ run from within an rx.call_script eval context.
+    return Var.create_safe(
+        f"{arg_def} => {{queueEvents([{','.join(payloads)}], {constants.CompileVars.SOCKET}); "
+        f"processEvent({constants.CompileVars.SOCKET})}}",
+        _var_is_string=False,
+        _var_is_local=False,
+    ).to(EventChain)
+
+
 def format_query_params(router_data: dict[str, Any]) -> dict[str, str]:
 def format_query_params(router_data: dict[str, Any]) -> dict[str, str]:
     """Convert back query params name to python-friendly case.
     """Convert back query params name to python-friendly case.