Pārlūkot izejas kodu

support eventspec/eventchain in var operations (#4038)

Khaleel Al-Adhami 7 mēneši atpakaļ
vecāks
revīzija
73e8a4e0ab

+ 17 - 5
reflex/.templates/web/utils/state.js

@@ -544,13 +544,19 @@ export const uploadFiles = async (
 
 
 /**
 /**
  * Create an event object.
  * Create an event object.
- * @param name The name of the event.
- * @param payload The payload of the event.
- * @param handler The client handler to process event.
+ * @param {string} name The name of the event.
+ * @param {Object.<string, Any>} payload The payload of the event.
+ * @param {Object.<string, (number|boolean)>} event_actions The actions to take on the event.
+ * @param {string} handler The client handler to process event.
  * @returns The event object.
  * @returns The event object.
  */
  */
-export const Event = (name, payload = {}, handler = null) => {
-  return { name, payload, handler };
+export const Event = (
+  name,
+  payload = {},
+  event_actions = {},
+  handler = null
+) => {
+  return { name, payload, handler, event_actions };
 };
 };
 
 
 /**
 /**
@@ -676,6 +682,12 @@ export const useEventLoop = (
     if (!(args instanceof Array)) {
     if (!(args instanceof Array)) {
       args = [args];
       args = [args];
     }
     }
+
+    event_actions = events.reduce(
+      (acc, e) => ({ ...acc, ...e.event_actions }),
+      event_actions ?? {}
+    );
+
     const _e = args.filter((o) => o?.preventDefault !== undefined)[0];
     const _e = args.filter((o) => o?.preventDefault !== undefined)[0];
 
 
     if (event_actions?.preventDefault && _e?.preventDefault) {
     if (event_actions?.preventDefault && _e?.preventDefault) {

+ 3 - 1
reflex/app.py

@@ -1536,7 +1536,9 @@ class EventNamespace(AsyncNamespace):
         """
         """
         fields = json.loads(data)
         fields = json.loads(data)
         # Get the event.
         # Get the event.
-        event = Event(**{k: v for k, v in fields.items() if k != "handler"})
+        event = Event(
+            **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}
+        )
 
 
         self.token_to_sid[event.token] = sid
         self.token_to_sid[event.token] = sid
         self.sid_to_token[sid] = event.token
         self.sid_to_token[sid] = event.token

+ 32 - 18
reflex/components/component.py

@@ -38,8 +38,10 @@ from reflex.constants import (
 )
 )
 from reflex.event import (
 from reflex.event import (
     EventChain,
     EventChain,
+    EventChainVar,
     EventHandler,
     EventHandler,
     EventSpec,
     EventSpec,
+    EventVar,
     call_event_fn,
     call_event_fn,
     call_event_handler,
     call_event_handler,
     get_handler_args,
     get_handler_args,
@@ -514,7 +516,7 @@ class Component(BaseComponent, ABC):
             Var,
             Var,
             EventHandler,
             EventHandler,
             EventSpec,
             EventSpec,
-            List[Union[EventHandler, EventSpec]],
+            List[Union[EventHandler, EventSpec, EventVar]],
             Callable,
             Callable,
         ],
         ],
     ) -> Union[EventChain, Var]:
     ) -> Union[EventChain, Var]:
@@ -532,11 +534,16 @@ class Component(BaseComponent, ABC):
         """
         """
         # If it's an event chain var, return it.
         # If it's an event chain var, return it.
         if isinstance(value, Var):
         if isinstance(value, Var):
-            if value._var_type is not EventChain:
+            if isinstance(value, EventChainVar):
+                return value
+            elif isinstance(value, EventVar):
+                value = [value]
+            elif issubclass(value._var_type, (EventChain, EventSpec)):
+                return self._create_event_chain(args_spec, value.guess_type())
+            else:
                 raise ValueError(
                 raise ValueError(
-                    f"Invalid event chain: {repr(value)} of type {type(value)}"
+                    f"Invalid event chain: {str(value)} of type {value._var_type}"
                 )
                 )
-            return value
         elif isinstance(value, EventChain):
         elif isinstance(value, EventChain):
             # Trust that the caller knows what they're doing passing an EventChain directly
             # Trust that the caller knows what they're doing passing an EventChain directly
             return value
             return value
@@ -547,7 +554,7 @@ class Component(BaseComponent, ABC):
 
 
         # If the input is a list of event handlers, create an event chain.
         # If the input is a list of event handlers, create an event chain.
         if isinstance(value, List):
         if isinstance(value, List):
-            events: list[EventSpec] = []
+            events: List[Union[EventSpec, EventVar]] = []
             for v in value:
             for v in value:
                 if isinstance(v, (EventHandler, EventSpec)):
                 if isinstance(v, (EventHandler, EventSpec)):
                     # Call the event handler to get the event.
                     # Call the event handler to get the event.
@@ -561,6 +568,8 @@ class Component(BaseComponent, ABC):
                             "lambda inside an EventChain list."
                             "lambda inside an EventChain list."
                         )
                         )
                     events.extend(result)
                     events.extend(result)
+                elif isinstance(v, EventVar):
+                    events.append(v)
                 else:
                 else:
                     raise ValueError(f"Invalid event: {v}")
                     raise ValueError(f"Invalid event: {v}")
 
 
@@ -570,32 +579,30 @@ class Component(BaseComponent, ABC):
             if isinstance(result, Var):
             if isinstance(result, Var):
                 # Recursively call this function if the lambda returned an EventChain Var.
                 # Recursively call this function if the lambda returned an EventChain Var.
                 return self._create_event_chain(args_spec, result)
                 return self._create_event_chain(args_spec, result)
-            events = result
+            events = [*result]
 
 
         # Otherwise, raise an error.
         # Otherwise, raise an error.
         else:
         else:
             raise ValueError(f"Invalid event chain: {value}")
             raise ValueError(f"Invalid event chain: {value}")
 
 
         # Add args to the event specs if necessary.
         # Add args to the event specs if necessary.
-        events = [e.with_args(get_handler_args(e)) for e in events]
-
-        # Collect event_actions from each spec
-        event_actions = {}
-        for e in events:
-            event_actions.update(e.event_actions)
+        events = [
+            (e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e)
+            for e in events
+        ]
 
 
         # Return the event chain.
         # Return the event chain.
         if isinstance(args_spec, Var):
         if isinstance(args_spec, Var):
             return EventChain(
             return EventChain(
                 events=events,
                 events=events,
                 args_spec=None,
                 args_spec=None,
-                event_actions=event_actions,
+                event_actions={},
             )
             )
         else:
         else:
             return EventChain(
             return EventChain(
                 events=events,
                 events=events,
                 args_spec=args_spec,
                 args_spec=args_spec,
-                event_actions=event_actions,
+                event_actions={},
             )
             )
 
 
     def get_event_triggers(self) -> Dict[str, Any]:
     def get_event_triggers(self) -> Dict[str, Any]:
@@ -1030,8 +1037,11 @@ class Component(BaseComponent, ABC):
             elif isinstance(event, EventChain):
             elif isinstance(event, EventChain):
                 event_args = []
                 event_args = []
                 for spec in event.events:
                 for spec in event.events:
-                    for args in spec.args:
-                        event_args.extend(args)
+                    if isinstance(spec, EventSpec):
+                        for args in spec.args:
+                            event_args.extend(args)
+                    else:
+                        event_args.append(spec)
                 yield event_trigger, event_args
                 yield event_trigger, event_args
 
 
     def _get_vars(self, include_children: bool = False) -> list[Var]:
     def _get_vars(self, include_children: bool = False) -> list[Var]:
@@ -1105,8 +1115,12 @@ class Component(BaseComponent, ABC):
         for trigger in self.event_triggers.values():
         for trigger in self.event_triggers.values():
             if isinstance(trigger, EventChain):
             if isinstance(trigger, EventChain):
                 for event in trigger.events:
                 for event in trigger.events:
-                    if event.handler.state_full_name:
-                        return True
+                    if isinstance(event, EventSpec):
+                        if event.handler.state_full_name:
+                            return True
+                    else:
+                        if event._var_state:
+                            return True
             elif isinstance(trigger, Var) and trigger._var_state:
             elif isinstance(trigger, Var) and trigger._var_state:
                 return True
                 return True
         return False
         return False

+ 189 - 4
reflex/event.py

@@ -4,16 +4,19 @@ from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
 import inspect
 import inspect
+import sys
 import types
 import types
 import urllib.parse
 import urllib.parse
 from base64 import b64encode
 from base64 import b64encode
 from typing import (
 from typing import (
     Any,
     Any,
     Callable,
     Callable,
+    ClassVar,
     Dict,
     Dict,
     List,
     List,
     Optional,
     Optional,
     Tuple,
     Tuple,
+    Type,
     Union,
     Union,
     get_type_hints,
     get_type_hints,
 )
 )
@@ -25,8 +28,15 @@ from reflex.utils import format
 from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
 from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
 from reflex.utils.types import ArgsSpec, GenericType
 from reflex.utils.types import ArgsSpec, GenericType
 from reflex.vars import VarData
 from reflex.vars import VarData
-from reflex.vars.base import LiteralVar, Var
-from reflex.vars.function import FunctionStringVar, FunctionVar
+from reflex.vars.base import (
+    CachedVarOperation,
+    LiteralNoneVar,
+    LiteralVar,
+    ToOperation,
+    Var,
+    cached_property_no_lock,
+)
+from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar
 from reflex.vars.object import ObjectVar
 from reflex.vars.object import ObjectVar
 
 
 try:
 try:
@@ -375,7 +385,7 @@ class CallableEventSpec(EventSpec):
 class EventChain(EventActionsMixin):
 class EventChain(EventActionsMixin):
     """Container for a chain of events that will be executed in order."""
     """Container for a chain of events that will be executed in order."""
 
 
-    events: List[EventSpec] = dataclasses.field(default_factory=list)
+    events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list)
 
 
     args_spec: Optional[Callable] = dataclasses.field(default=None)
     args_spec: Optional[Callable] = dataclasses.field(default=None)
 
 
@@ -478,7 +488,7 @@ class FileUpload:
             if isinstance(events, Var):
             if isinstance(events, Var):
                 raise ValueError(f"{on_upload_progress} cannot return a var {events}.")
                 raise ValueError(f"{on_upload_progress} cannot return a var {events}.")
             on_upload_progress_chain = EventChain(
             on_upload_progress_chain = EventChain(
-                events=events,
+                events=[*events],
                 args_spec=self.on_upload_progress_args_spec,
                 args_spec=self.on_upload_progress_args_spec,
             )
             )
             formatted_chain = str(format.format_prop(on_upload_progress_chain))
             formatted_chain = str(format.format_prop(on_upload_progress_chain))
@@ -1136,3 +1146,178 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
         "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
         "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
     )
     )
     return signature.replace(parameters=(new_param, *signature.parameters.values()))
     return signature.replace(parameters=(new_param, *signature.parameters.values()))
+
+
+class EventVar(ObjectVar):
+    """Base class for event vars."""
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
+    """A literal event var."""
+
+    _var_value: EventSpec = dataclasses.field(default=None)  # type: ignore
+
+    def __hash__(self) -> int:
+        """Get the hash of the var.
+
+        Returns:
+            The hash of the var.
+        """
+        return hash((self.__class__.__name__, self._js_expr))
+
+    @cached_property_no_lock
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        return str(
+            FunctionStringVar("Event").call(
+                # event handler name
+                ".".join(
+                    filter(
+                        None,
+                        format.get_event_handler_parts(self._var_value.handler),
+                    )
+                ),
+                # event handler args
+                {str(name): value for name, value in self._var_value.args},
+                # event actions
+                self._var_value.event_actions,
+                # client handler name
+                *(
+                    [self._var_value.client_handler_name]
+                    if self._var_value.client_handler_name
+                    else []
+                ),
+            )
+        )
+
+    @classmethod
+    def create(
+        cls,
+        value: EventSpec,
+        _var_data: VarData | None = None,
+    ) -> LiteralEventVar:
+        """Create a new LiteralEventVar instance.
+
+        Args:
+            value: The value of the var.
+            _var_data: The data of the var.
+
+        Returns:
+            The created LiteralEventVar instance.
+        """
+        return cls(
+            _js_expr="",
+            _var_type=EventSpec,
+            _var_data=_var_data,
+            _var_value=value,
+        )
+
+
+class EventChainVar(FunctionVar):
+    """Base class for event chain vars."""
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
+    """A literal event chain var."""
+
+    _var_value: EventChain = dataclasses.field(default=None)  # type: ignore
+
+    def __hash__(self) -> int:
+        """Get the hash of the var.
+
+        Returns:
+            The hash of the var.
+        """
+        return hash((self.__class__.__name__, self._js_expr))
+
+    @cached_property_no_lock
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        sig = inspect.signature(self._var_value.args_spec)  # type: ignore
+        if sig.parameters:
+            arg_def = tuple((f"_{p}" for p in sig.parameters))
+            arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
+        else:
+            # add a default argument for addEvents if none were specified in value.args_spec
+            # used to trigger the preventDefault() on the event.
+            arg_def = ("...args",)
+            arg_def_expr = Var(_js_expr="args")
+
+        return str(
+            ArgsFunctionOperation.create(
+                arg_def,
+                FunctionStringVar.create("addEvents").call(
+                    LiteralVar.create(
+                        [LiteralVar.create(event) for event in self._var_value.events]
+                    ),
+                    arg_def_expr,
+                    self._var_value.event_actions,
+                ),
+            )
+        )
+
+    @classmethod
+    def create(
+        cls,
+        value: EventChain,
+        _var_data: VarData | None = None,
+    ) -> LiteralEventChainVar:
+        """Create a new LiteralEventChainVar instance.
+
+        Args:
+            value: The value of the var.
+            _var_data: The data of the var.
+
+        Returns:
+            The created LiteralEventChainVar instance.
+        """
+        return cls(
+            _js_expr="",
+            _var_type=EventChain,
+            _var_data=_var_data,
+            _var_value=value,
+        )
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class ToEventVarOperation(ToOperation, EventVar):
+    """Result of a cast to an event var."""
+
+    _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
+
+    _default_var_type: ClassVar[Type] = EventSpec
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class ToEventChainVarOperation(ToOperation, EventChainVar):
+    """Result of a cast to an event chain var."""
+
+    _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
+
+    _default_var_type: ClassVar[Type] = EventChain

+ 1 - 13
reflex/utils/format.py

@@ -359,19 +359,7 @@ def format_prop(
 
 
         # Handle event props.
         # Handle event props.
         if isinstance(prop, EventChain):
         if isinstance(prop, EventChain):
-            sig = inspect.signature(prop.args_spec)  # type: ignore
-            if sig.parameters:
-                arg_def = ",".join(f"_{p}" for p in sig.parameters)
-                arg_def_expr = f"[{arg_def}]"
-            else:
-                # add a default argument for addEvents if none were specified in prop.args_spec
-                # used to trigger the preventDefault() on the event.
-                arg_def = "...args"
-                arg_def_expr = "args"
-
-            chain = ",".join([format_event(event) for event in prop.events])
-            event = f"addEvents([{chain}], {arg_def_expr}, {json_dumps(prop.event_actions)})"
-            prop = f"({arg_def}) => {event}"
+            return str(Var.create(prop))
 
 
         # Handle other types.
         # Handle other types.
         elif isinstance(prop, str):
         elif isinstance(prop, str):

+ 44 - 37
reflex/vars/base.py

@@ -385,6 +385,15 @@ class Var(Generic[VAR_TYPE]):
         Returns:
         Returns:
             The converted var.
             The converted var.
         """
         """
+        from reflex.event import (
+            EventChain,
+            EventChainVar,
+            EventSpec,
+            EventVar,
+            ToEventChainVarOperation,
+            ToEventVarOperation,
+        )
+
         from .function import FunctionVar, ToFunctionOperation
         from .function import FunctionVar, ToFunctionOperation
         from .number import (
         from .number import (
             BooleanVar,
             BooleanVar,
@@ -416,6 +425,10 @@ class Var(Generic[VAR_TYPE]):
             return self.to(BooleanVar, output)
             return self.to(BooleanVar, output)
         if fixed_output_type is None:
         if fixed_output_type is None:
             return ToNoneOperation.create(self)
             return ToNoneOperation.create(self)
+        if fixed_output_type is EventSpec:
+            return self.to(EventVar, output)
+        if fixed_output_type is EventChain:
+            return self.to(EventChainVar, output)
         if issubclass(fixed_output_type, Base):
         if issubclass(fixed_output_type, Base):
             return self.to(ObjectVar, output)
             return self.to(ObjectVar, output)
         if dataclasses.is_dataclass(fixed_output_type) and not issubclass(
         if dataclasses.is_dataclass(fixed_output_type) and not issubclass(
@@ -453,10 +466,13 @@ class Var(Generic[VAR_TYPE]):
         if issubclass(output, StringVar):
         if issubclass(output, StringVar):
             return ToStringOperation.create(self, var_type or str)
             return ToStringOperation.create(self, var_type or str)
 
 
-        if issubclass(output, (ObjectVar, Base)):
-            return ToObjectOperation.create(self, var_type or dict)
+        if issubclass(output, EventVar):
+            return ToEventVarOperation.create(self, var_type or EventSpec)
 
 
-        if dataclasses.is_dataclass(output):
+        if issubclass(output, EventChainVar):
+            return ToEventChainVarOperation.create(self, var_type or EventChain)
+
+        if issubclass(output, (ObjectVar, Base)):
             return ToObjectOperation.create(self, var_type or dict)
             return ToObjectOperation.create(self, var_type or dict)
 
 
         if issubclass(output, FunctionVar):
         if issubclass(output, FunctionVar):
@@ -469,6 +485,9 @@ class Var(Generic[VAR_TYPE]):
         if issubclass(output, NoneVar):
         if issubclass(output, NoneVar):
             return ToNoneOperation.create(self)
             return ToNoneOperation.create(self)
 
 
+        if dataclasses.is_dataclass(output):
+            return ToObjectOperation.create(self, var_type or dict)
+
         # If we can't determine the first argument, we just replace the _var_type.
         # If we can't determine the first argument, we just replace the _var_type.
         if not issubclass(output, Var) or var_type is None:
         if not issubclass(output, Var) or var_type is None:
             return dataclasses.replace(
             return dataclasses.replace(
@@ -494,6 +513,8 @@ class Var(Generic[VAR_TYPE]):
         Raises:
         Raises:
             TypeError: If the type is not supported for guessing.
             TypeError: If the type is not supported for guessing.
         """
         """
+        from reflex.event import EventChain, EventChainVar, EventSpec, EventVar
+
         from .number import BooleanVar, NumberVar
         from .number import BooleanVar, NumberVar
         from .object import ObjectVar
         from .object import ObjectVar
         from .sequence import ArrayVar, StringVar
         from .sequence import ArrayVar, StringVar
@@ -539,6 +560,10 @@ class Var(Generic[VAR_TYPE]):
             return self.to(ArrayVar, self._var_type)
             return self.to(ArrayVar, self._var_type)
         if issubclass(fixed_type, str):
         if issubclass(fixed_type, str):
             return self.to(StringVar, self._var_type)
             return self.to(StringVar, self._var_type)
+        if issubclass(fixed_type, EventSpec):
+            return self.to(EventVar, self._var_type)
+        if issubclass(fixed_type, EventChain):
+            return self.to(EventChainVar, self._var_type)
         if issubclass(fixed_type, Base):
         if issubclass(fixed_type, Base):
             return self.to(ObjectVar, self._var_type)
             return self.to(ObjectVar, self._var_type)
         if dataclasses.is_dataclass(fixed_type):
         if dataclasses.is_dataclass(fixed_type):
@@ -1029,47 +1054,22 @@ class LiteralVar(Var):
         if value is None:
         if value is None:
             return LiteralNoneVar.create(_var_data=_var_data)
             return LiteralNoneVar.create(_var_data=_var_data)
 
 
-        from reflex.event import EventChain, EventHandler, EventSpec
+        from reflex.event import (
+            EventChain,
+            EventHandler,
+            EventSpec,
+            LiteralEventChainVar,
+            LiteralEventVar,
+        )
         from reflex.utils.format import get_event_handler_parts
         from reflex.utils.format import get_event_handler_parts
 
 
-        from .function import ArgsFunctionOperation, FunctionStringVar
         from .object import LiteralObjectVar
         from .object import LiteralObjectVar
 
 
         if isinstance(value, EventSpec):
         if isinstance(value, EventSpec):
-            event_name = LiteralVar.create(
-                ".".join(filter(None, get_event_handler_parts(value.handler)))
-            )
-            event_args = LiteralVar.create(
-                {str(name): value for name, value in value.args}
-            )
-            event_client_name = LiteralVar.create(value.client_handler_name)
-            return FunctionStringVar("Event").call(
-                event_name,
-                event_args,
-                *([event_client_name] if value.client_handler_name else []),
-            )
+            return LiteralEventVar.create(value, _var_data=_var_data)
 
 
         if isinstance(value, EventChain):
         if isinstance(value, EventChain):
-            sig = inspect.signature(value.args_spec)  # type: ignore
-            if sig.parameters:
-                arg_def = tuple((f"_{p}" for p in sig.parameters))
-                arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
-            else:
-                # add a default argument for addEvents if none were specified in value.args_spec
-                # used to trigger the preventDefault() on the event.
-                arg_def = ("...args",)
-                arg_def_expr = Var(_js_expr="args")
-
-            return ArgsFunctionOperation.create(
-                arg_def,
-                FunctionStringVar.create("addEvents").call(
-                    LiteralVar.create(
-                        [LiteralVar.create(event) for event in value.events]
-                    ),
-                    arg_def_expr,
-                    LiteralVar.create(value.event_actions),
-                ),
-            )
+            return LiteralEventChainVar.create(value, _var_data=_var_data)
 
 
         if isinstance(value, EventHandler):
         if isinstance(value, EventHandler):
             return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value))))
             return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value))))
@@ -2126,9 +2126,16 @@ class NoneVar(Var[None]):
     """A var representing None."""
     """A var representing None."""
 
 
 
 
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
 class LiteralNoneVar(LiteralVar, NoneVar):
 class LiteralNoneVar(LiteralVar, NoneVar):
     """A var representing None."""
     """A var representing None."""
 
 
+    _var_value: None = None
+
     def json(self) -> str:
     def json(self) -> str:
         """Serialize the var to a JSON string.
         """Serialize the var to a JSON string.
 
 

+ 3 - 3
tests/units/components/base/test_script.py

@@ -58,14 +58,14 @@ def test_script_event_handler():
     )
     )
     render_dict = component.render()
     render_dict = component.render()
     assert (
     assert (
-        f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{  }})))], args, ({{  }})))))}}'
+        f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{  }}), ({{  }})))], args, ({{  }})))))}}'
         in render_dict["props"]
         in render_dict["props"]
     )
     )
     assert (
     assert (
-        f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{  }})))], args, ({{  }})))))}}'
+        f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{  }}), ({{  }})))], args, ({{  }})))))}}'
         in render_dict["props"]
         in render_dict["props"]
     )
     )
     assert (
     assert (
-        f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{  }})))], args, ({{  }})))))}}'
+        f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{  }}), ({{  }})))], args, ({{  }})))))}}'
         in render_dict["props"]
         in render_dict["props"]
     )
     )

+ 4 - 4
tests/units/components/test_component.py

@@ -832,7 +832,7 @@ def test_component_event_trigger_arbitrary_args():
 
 
     assert comp.render()["props"][0] == (
     assert comp.render()["props"][0] == (
         "onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents("
         "onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents("
-        f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }})))], '
+        f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{  }})))], '
         "[__e, _alpha, _bravo, _charlie], ({  })))))}"
         "[__e, _alpha, _bravo, _charlie], ({  })))))}"
     )
     )
 
 
@@ -1178,7 +1178,7 @@ TEST_VAR = LiteralVar.create("test")._replace(
 )
 )
 FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar")
 FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar")
 STYLE_VAR = TEST_VAR._replace(_js_expr="style")
 STYLE_VAR = TEST_VAR._replace(_js_expr="style")
-EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
+EVENT_CHAIN_VAR = TEST_VAR.to(EventChain)
 ARG_VAR = Var(_js_expr="arg")
 ARG_VAR = Var(_js_expr="arg")
 
 
 TEST_VAR_DICT_OF_DICT = LiteralVar.create({"a": {"b": "test"}})._replace(
 TEST_VAR_DICT_OF_DICT = LiteralVar.create({"a": {"b": "test"}})._replace(
@@ -2159,7 +2159,7 @@ class TriggerState(rx.State):
                 rx.text("random text", on_click=TriggerState.do_something),
                 rx.text("random text", on_click=TriggerState.do_something),
                 rx.text(
                 rx.text(
                     "random text",
                     "random text",
-                    on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain),
+                    on_click=Var(_js_expr="toggleColorMode").to(EventChain),
                 ),
                 ),
             ),
             ),
             True,
             True,
@@ -2169,7 +2169,7 @@ class TriggerState(rx.State):
                 rx.text("random text", on_click=rx.console_log("log")),
                 rx.text("random text", on_click=rx.console_log("log")),
                 rx.text(
                 rx.text(
                     "random text",
                     "random text",
-                    on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain),
+                    on_click=Var(_js_expr="toggleColorMode").to(EventChain),
                 ),
                 ),
             ),
             ),
             False,
             False,

+ 17 - 5
tests/units/utils/test_format.py

@@ -374,7 +374,7 @@ def test_format_match(
                 events=[EventSpec(handler=EventHandler(fn=mock_event))],
                 events=[EventSpec(handler=EventHandler(fn=mock_event))],
                 args_spec=lambda: [],
                 args_spec=lambda: [],
             ),
             ),
-            '((...args) => ((addEvents([(Event("mock_event", ({  })))], args, ({  })))))',
+            '((...args) => ((addEvents([(Event("mock_event", ({  }), ({  })))], args, ({  })))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -395,7 +395,7 @@ def test_format_match(
                 ],
                 ],
                 args_spec=lambda e: [e.target.value],
                 args_spec=lambda e: [e.target.value],
             ),
             ),
-            '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({  })))))',
+            '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({  })))], [_e], ({  })))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -403,7 +403,19 @@ def test_format_match(
                 args_spec=lambda: [],
                 args_spec=lambda: [],
                 event_actions={"stopPropagation": True},
                 event_actions={"stopPropagation": True},
             ),
             ),
-            '((...args) => ((addEvents([(Event("mock_event", ({  })))], args, ({ ["stopPropagation"] : true })))))',
+            '((...args) => ((addEvents([(Event("mock_event", ({  }), ({  })))], args, ({ ["stopPropagation"] : true })))))',
+        ),
+        (
+            EventChain(
+                events=[
+                    EventSpec(
+                        handler=EventHandler(fn=mock_event),
+                        event_actions={"stopPropagation": True},
+                    )
+                ],
+                args_spec=lambda: [],
+            ),
+            '((...args) => ((addEvents([(Event("mock_event", ({  }), ({ ["stopPropagation"] : true })))], args, ({  })))))',
         ),
         ),
         (
         (
             EventChain(
             EventChain(
@@ -411,7 +423,7 @@ def test_format_match(
                 args_spec=lambda: [],
                 args_spec=lambda: [],
                 event_actions={"preventDefault": True},
                 event_actions={"preventDefault": True},
             ),
             ),
-            '((...args) => ((addEvents([(Event("mock_event", ({  })))], args, ({ ["preventDefault"] : true })))))',
+            '((...args) => ((addEvents([(Event("mock_event", ({  }), ({  })))], args, ({ ["preventDefault"] : true })))))',
         ),
         ),
         ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
         ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
         (Var(_js_expr="var", _var_type=int).guess_type(), "var"),
         (Var(_js_expr="var", _var_type=int).guess_type(), "var"),
@@ -519,7 +531,7 @@ def test_format_event_handler(input, output):
     [
     [
         (
         (
             EventSpec(handler=EventHandler(fn=mock_event)),
             EventSpec(handler=EventHandler(fn=mock_event)),
-            '(Event("mock_event", ({  })))',
+            '(Event("mock_event", ({  }), ({  })))',
         ),
         ),
     ],
     ],
 )
 )