Переглянути джерело

allow for event handlers to ignore args (#4282)

* allow for event handlers to ignore args

* use a constant

* dang it darglint

* forgor

* keep the tests but move them to valid place
Khaleel Al-Adhami 6 місяців тому
батько
коміт
6334cfab0d

+ 12 - 5
reflex/components/component.py

@@ -17,6 +17,7 @@ from typing import (
     Iterator,
     Iterator,
     List,
     List,
     Optional,
     Optional,
+    Sequence,
     Set,
     Set,
     Type,
     Type,
     Union,
     Union,
@@ -38,6 +39,7 @@ from reflex.constants import (
     PageNames,
     PageNames,
 )
 )
 from reflex.constants.compiler import SpecialAttributes
 from reflex.constants.compiler import SpecialAttributes
+from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.event import (
 from reflex.event import (
     EventCallback,
     EventCallback,
     EventChain,
     EventChain,
@@ -533,7 +535,7 @@ class Component(BaseComponent, ABC):
 
 
     def _create_event_chain(
     def _create_event_chain(
         self,
         self,
-        args_spec: Any,
+        args_spec: types.ArgsSpec | Sequence[types.ArgsSpec],
         value: Union[
         value: Union[
             Var,
             Var,
             EventHandler,
             EventHandler,
@@ -599,7 +601,7 @@ class Component(BaseComponent, ABC):
 
 
         # If the input is a callable, create an event chain.
         # If the input is a callable, create an event chain.
         elif isinstance(value, Callable):
         elif isinstance(value, Callable):
-            result = call_event_fn(value, args_spec)
+            result = call_event_fn(value, args_spec, key=key)
             if isinstance(result, Var):
             if isinstance(result, Var):
                 # Recursively call this function if the lambda returned an EventChain Var.
                 # Recursively call this function if the lambda returned an EventChain Var.
                 return self._create_event_chain(args_spec, result, key=key)
                 return self._create_event_chain(args_spec, result, key=key)
@@ -629,14 +631,16 @@ class Component(BaseComponent, ABC):
                 event_actions={},
                 event_actions={},
             )
             )
 
 
-    def get_event_triggers(self) -> Dict[str, Any]:
+    def get_event_triggers(
+        self,
+    ) -> Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]]:
         """Get the event triggers for the component.
         """Get the event triggers for the component.
 
 
         Returns:
         Returns:
             The event triggers.
             The event triggers.
 
 
         """
         """
-        default_triggers = {
+        default_triggers: Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
             EventTriggers.ON_FOCUS: no_args_event_spec,
             EventTriggers.ON_FOCUS: no_args_event_spec,
             EventTriggers.ON_BLUR: no_args_event_spec,
             EventTriggers.ON_BLUR: no_args_event_spec,
             EventTriggers.ON_CLICK: no_args_event_spec,
             EventTriggers.ON_CLICK: no_args_event_spec,
@@ -1142,7 +1146,10 @@ class Component(BaseComponent, ABC):
                     if isinstance(event, EventCallback):
                     if isinstance(event, EventCallback):
                         continue
                         continue
                     if isinstance(event, EventSpec):
                     if isinstance(event, EventSpec):
-                        if event.handler.state_full_name:
+                        if (
+                            event.handler.state_full_name
+                            and event.handler.state_full_name != FRONTEND_EVENT_STATE
+                        ):
                             return True
                             return True
                     else:
                     else:
                         if event._var_state:
                         if event._var_state:

+ 4 - 0
reflex/constants/state.py

@@ -9,3 +9,7 @@ class StateManagerMode(str, Enum):
     DISK = "disk"
     DISK = "disk"
     MEMORY = "memory"
     MEMORY = "memory"
     REDIS = "redis"
     REDIS = "redis"
+
+
+# Used for things like console_log, etc.
+FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state"

+ 73 - 64
reflex/event.py

@@ -28,10 +28,10 @@ from typing import (
 from typing_extensions import ParamSpec, Protocol, get_args, get_origin
 from typing_extensions import ParamSpec, Protocol, get_args, get_origin
 
 
 from reflex import constants
 from reflex import constants
+from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.utils import console, format
 from reflex.utils import console, format
 from reflex.utils.exceptions import (
 from reflex.utils.exceptions import (
     EventFnArgMismatch,
     EventFnArgMismatch,
-    EventHandlerArgMismatch,
     EventHandlerArgTypeMismatch,
     EventHandlerArgTypeMismatch,
 )
 )
 from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
 from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
@@ -662,7 +662,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
     fn.__qualname__ = name
     fn.__qualname__ = name
     fn.__signature__ = sig
     fn.__signature__ = sig
     return EventSpec(
     return EventSpec(
-        handler=EventHandler(fn=fn),
+        handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
         args=tuple(
         args=tuple(
             (
             (
                 Var(_js_expr=k),
                 Var(_js_expr=k),
@@ -1092,8 +1092,8 @@ def get_hydrate_event(state) -> str:
 
 
 
 
 def call_event_handler(
 def call_event_handler(
-    event_handler: EventHandler | EventSpec,
-    arg_spec: ArgsSpec | Sequence[ArgsSpec],
+    event_callback: EventHandler | EventSpec,
+    event_spec: ArgsSpec | Sequence[ArgsSpec],
     key: Optional[str] = None,
     key: Optional[str] = None,
 ) -> EventSpec:
 ) -> EventSpec:
     """Call an event handler to get the event spec.
     """Call an event handler to get the event spec.
@@ -1103,53 +1103,57 @@ def call_event_handler(
     Otherwise, the event handler will be called with no args.
     Otherwise, the event handler will be called with no args.
 
 
     Args:
     Args:
-        event_handler: The event handler.
-        arg_spec: The lambda that define the argument(s) to pass to the event handler.
+        event_callback: The event handler.
+        event_spec: The lambda that define the argument(s) to pass to the event handler.
         key: The key to pass to the event handler.
         key: The key to pass to the event handler.
 
 
-    Raises:
-        EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
-
     Returns:
     Returns:
         The event spec from calling the event handler.
         The event spec from calling the event handler.
 
 
     # noqa: DAR401 failure
     # noqa: DAR401 failure
 
 
     """
     """
-    parsed_args = parse_args_spec(arg_spec)  # type: ignore
-
-    if isinstance(event_handler, EventSpec):
-        # Handle partial application of EventSpec args
-        return event_handler.add_args(*parsed_args)
-
-    provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
-
-    provided_callback_n_args = (
-        len(provided_callback_fullspec.args) - 1
-    )  # subtract 1 for bound self arg
-
-    if provided_callback_n_args != len(parsed_args):
-        raise EventHandlerArgMismatch(
-            "The number of arguments accepted by "
-            f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
-            "does not match the arguments passed by the event trigger: "
-            f"{[str(v) for v in parsed_args]}\n"
-            "See https://reflex.dev/docs/events/event-arguments/"
+    event_spec_args = parse_args_spec(event_spec)  # type: ignore
+
+    if isinstance(event_callback, EventSpec):
+        check_fn_match_arg_spec(
+            event_callback.handler.fn,
+            event_spec,
+            key,
+            bool(event_callback.handler.state_full_name) + len(event_callback.args),
+            event_callback.handler.fn.__qualname__,
         )
         )
+        # Handle partial application of EventSpec args
+        return event_callback.add_args(*event_spec_args)
+
+    check_fn_match_arg_spec(
+        event_callback.fn,
+        event_spec,
+        key,
+        bool(event_callback.state_full_name),
+        event_callback.fn.__qualname__,
+    )
 
 
-    all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
+    all_acceptable_specs = (
+        [event_spec] if not isinstance(event_spec, Sequence) else event_spec
+    )
 
 
     event_spec_return_types = list(
     event_spec_return_types = list(
         filter(
         filter(
             lambda event_spec_return_type: event_spec_return_type is not None
             lambda event_spec_return_type: event_spec_return_type is not None
             and get_origin(event_spec_return_type) is tuple,
             and get_origin(event_spec_return_type) is tuple,
-            (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
+            (
+                get_type_hints(arg_spec).get("return", None)
+                for arg_spec in all_acceptable_specs
+            ),
         )
         )
     )
     )
 
 
     if event_spec_return_types:
     if event_spec_return_types:
         failures = []
         failures = []
 
 
+        event_callback_spec = inspect.getfullargspec(event_callback.fn)
+
         for event_spec_index, event_spec_return_type in enumerate(
         for event_spec_index, event_spec_return_type in enumerate(
             event_spec_return_types
             event_spec_return_types
         ):
         ):
@@ -1160,14 +1164,14 @@ def call_event_handler(
             ]
             ]
 
 
             try:
             try:
-                type_hints_of_provided_callback = get_type_hints(event_handler.fn)
+                type_hints_of_provided_callback = get_type_hints(event_callback.fn)
             except NameError:
             except NameError:
                 type_hints_of_provided_callback = {}
                 type_hints_of_provided_callback = {}
 
 
             failed_type_check = False
             failed_type_check = False
 
 
             # check that args of event handler are matching the spec if type hints are provided
             # check that args of event handler are matching the spec if type hints are provided
-            for i, arg in enumerate(provided_callback_fullspec.args[1:]):
+            for i, arg in enumerate(event_callback_spec.args[1:]):
                 if arg not in type_hints_of_provided_callback:
                 if arg not in type_hints_of_provided_callback:
                     continue
                     continue
 
 
@@ -1181,7 +1185,7 @@ def call_event_handler(
                     #     f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
                     #     f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
                     # ) from e
                     # ) from e
                     console.warn(
                     console.warn(
-                        f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
+                        f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
                     )
                     )
                     compare_result = False
                     compare_result = False
 
 
@@ -1189,7 +1193,7 @@ def call_event_handler(
                     continue
                     continue
                 else:
                 else:
                     failure = EventHandlerArgTypeMismatch(
                     failure = EventHandlerArgTypeMismatch(
-                        f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
+                        f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
                     )
                     )
                     failures.append(failure)
                     failures.append(failure)
                     failed_type_check = True
                     failed_type_check = True
@@ -1210,14 +1214,14 @@ def call_event_handler(
 
 
                     given_string = ", ".join(
                     given_string = ", ".join(
                         repr(type_hints_of_provided_callback.get(arg, Any))
                         repr(type_hints_of_provided_callback.get(arg, Any))
-                        for arg in provided_callback_fullspec.args[1:]
+                        for arg in event_callback_spec.args[1:]
                     ).replace("[", "\\[")
                     ).replace("[", "\\[")
 
 
                     console.warn(
                     console.warn(
-                        f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
+                        f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
                         f"This may lead to unexpected behavior but is intentionally ignored for {key}."
                         f"This may lead to unexpected behavior but is intentionally ignored for {key}."
                     )
                     )
-                return event_handler(*parsed_args)
+                return event_callback(*event_spec_args)
 
 
         if failures:
         if failures:
             console.deprecate(
             console.deprecate(
@@ -1227,7 +1231,7 @@ def call_event_handler(
                 "0.7.0",
                 "0.7.0",
             )
             )
 
 
-    return event_handler(*parsed_args)  # type: ignore
+    return event_callback(*event_spec_args)  # type: ignore
 
 
 
 
 def unwrap_var_annotation(annotation: GenericType):
 def unwrap_var_annotation(annotation: GenericType):
@@ -1294,45 +1298,46 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
 
 
 
 
 def check_fn_match_arg_spec(
 def check_fn_match_arg_spec(
-    fn: Callable,
-    arg_spec: ArgsSpec,
-    key: Optional[str] = None,
-) -> List[Var]:
+    user_func: Callable,
+    arg_spec: ArgsSpec | Sequence[ArgsSpec],
+    key: str | None = None,
+    number_of_bound_args: int = 0,
+    func_name: str | None = None,
+):
     """Ensures that the function signature matches the passed argument specification
     """Ensures that the function signature matches the passed argument specification
     or raises an EventFnArgMismatch if they do not.
     or raises an EventFnArgMismatch if they do not.
 
 
     Args:
     Args:
-        fn: The function to be validated.
+        user_func: The function to be validated.
         arg_spec: The argument specification for the event trigger.
         arg_spec: The argument specification for the event trigger.
-        key: The key to pass to the event handler.
-
-    Returns:
-        The parsed arguments from the argument specification.
+        key: The key of the event trigger.
+        number_of_bound_args: The number of bound arguments to the function.
+        func_name: The name of the function to be validated.
 
 
     Raises:
     Raises:
         EventFnArgMismatch: Raised if the number of mandatory arguments do not match
         EventFnArgMismatch: Raised if the number of mandatory arguments do not match
     """
     """
-    fn_args = inspect.getfullargspec(fn).args
-    fn_defaults_args = inspect.getfullargspec(fn).defaults
-    n_fn_args = len(fn_args)
-    n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
-    if isinstance(fn, types.MethodType):
-        n_fn_args -= 1  # subtract 1 for bound self arg
-    parsed_args = parse_args_spec(arg_spec)
-    if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
+    user_args = inspect.getfullargspec(user_func).args
+    user_default_args = inspect.getfullargspec(user_func).defaults
+    number_of_user_args = len(user_args) - number_of_bound_args
+    number_of_user_default_args = len(user_default_args) if user_default_args else 0
+
+    parsed_event_args = parse_args_spec(arg_spec)
+
+    number_of_event_args = len(parsed_event_args)
+
+    if number_of_user_args - number_of_user_default_args > number_of_event_args:
         raise EventFnArgMismatch(
         raise EventFnArgMismatch(
-            "The number of mandatory arguments accepted by "
-            f"{fn} ({n_fn_args - n_fn_defaults_args}) "
-            "does not match the arguments passed by the event trigger: "
-            f"{[str(v) for v in parsed_args]}\n"
+            f"Event {key} only provides {number_of_event_args} arguments, but "
+            f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
+            "arguments to be passed to the event handler.\n"
             "See https://reflex.dev/docs/events/event-arguments/"
             "See https://reflex.dev/docs/events/event-arguments/"
         )
         )
-    return parsed_args
 
 
 
 
 def call_event_fn(
 def call_event_fn(
     fn: Callable,
     fn: Callable,
-    arg_spec: ArgsSpec,
+    arg_spec: ArgsSpec | Sequence[ArgsSpec],
     key: Optional[str] = None,
     key: Optional[str] = None,
 ) -> list[EventSpec] | Var:
 ) -> list[EventSpec] | Var:
     """Call a function to a list of event specs.
     """Call a function to a list of event specs.
@@ -1356,10 +1361,14 @@ def call_event_fn(
     from reflex.utils.exceptions import EventHandlerValueError
     from reflex.utils.exceptions import EventHandlerValueError
 
 
     # Check that fn signature matches arg_spec
     # Check that fn signature matches arg_spec
-    parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
+    check_fn_match_arg_spec(fn, arg_spec, key=key)
+
+    parsed_args = parse_args_spec(arg_spec)
+
+    number_of_fn_args = len(inspect.getfullargspec(fn).args)
 
 
     # Call the function with the parsed args.
     # Call the function with the parsed args.
-    out = fn(*parsed_args)
+    out = fn(*[*parsed_args][:number_of_fn_args])
 
 
     # If the function returns a Var, assume it's an EventChain and render it directly.
     # If the function returns a Var, assume it's an EventChain and render it directly.
     if isinstance(out, Var):
     if isinstance(out, Var):
@@ -1478,7 +1487,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
     """
     """
     signature = inspect.signature(fn)
     signature = inspect.signature(fn)
     new_param = inspect.Parameter(
     new_param = inspect.Parameter(
-        "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
+        FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
     )
     )
     return signature.replace(parameters=(new_param, *signature.parameters.values()))
     return signature.replace(parameters=(new_param, *signature.parameters.values()))
 
 

+ 1 - 5
reflex/utils/exceptions.py

@@ -89,16 +89,12 @@ class MatchTypeError(ReflexError, TypeError):
     """Raised when the return types of match cases are different."""
     """Raised when the return types of match cases are different."""
 
 
 
 
-class EventHandlerArgMismatch(ReflexError, TypeError):
-    """Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
-
-
 class EventHandlerArgTypeMismatch(ReflexError, TypeError):
 class EventHandlerArgTypeMismatch(ReflexError, TypeError):
     """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
     """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
 
 
 
 
 class EventFnArgMismatch(ReflexError, TypeError):
 class EventFnArgMismatch(ReflexError, TypeError):
-    """Raised when the number of args accepted by a lambda differs from that provided by the event trigger."""
+    """Raised when the number of args required by an event handler is more than provided by the event trigger."""
 
 
 
 
 class DynamicRouteArgShadowsStateVar(ReflexError, NameError):
 class DynamicRouteArgShadowsStateVar(ReflexError, NameError):

+ 2 - 1
reflex/utils/format.py

@@ -9,6 +9,7 @@ import re
 from typing import TYPE_CHECKING, Any, List, Optional, Union
 from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 
 from reflex import constants
 from reflex import constants
+from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.utils import exceptions
 from reflex.utils import exceptions
 from reflex.utils.console import deprecate
 from reflex.utils.console import deprecate
 
 
@@ -439,7 +440,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
 
 
     from reflex.state import State
     from reflex.state import State
 
 
-    if state_full_name == "state" and name not in State.__dict__:
+    if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__:
         return ("", to_snake_case(handler.fn.__qualname__))
         return ("", to_snake_case(handler.fn.__qualname__))
 
 
     return (state_full_name, name)
     return (state_full_name, name)

+ 2 - 2
reflex/utils/pyi_generator.py

@@ -16,7 +16,7 @@ from itertools import chain
 from multiprocessing import Pool, cpu_count
 from multiprocessing import Pool, cpu_count
 from pathlib import Path
 from pathlib import Path
 from types import ModuleType, SimpleNamespace
 from types import ModuleType, SimpleNamespace
-from typing import Any, Callable, Iterable, Type, get_args, get_origin
+from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
 
 
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.utils import types as rx_types
 from reflex.utils import types as rx_types
@@ -560,7 +560,7 @@ def _generate_component_create_functiondef(
                                     inspect.signature(event_specs).return_annotation
                                     inspect.signature(event_specs).return_annotation
                                 )
                                 )
                                 if not isinstance(
                                 if not isinstance(
-                                    event_specs := event_triggers[trigger], tuple
+                                    event_specs := event_triggers[trigger], Sequence
                                 )
                                 )
                                 else ast.Subscript(
                                 else ast.Subscript(
                                     ast.Name("Union"),
                                     ast.Name("Union"),

+ 27 - 36
tests/units/components/test_component.py

@@ -29,7 +29,6 @@ from reflex.style import Style
 from reflex.utils import imports
 from reflex.utils import imports
 from reflex.utils.exceptions import (
 from reflex.utils.exceptions import (
     EventFnArgMismatch,
     EventFnArgMismatch,
-    EventHandlerArgMismatch,
 )
 )
 from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.vars import VarData
 from reflex.vars import VarData
@@ -907,26 +906,14 @@ def test_invalid_event_handler_args(component2, test_state):
         test_state: A test state.
         test_state: A test state.
     """
     """
     # EventHandler args must match
     # EventHandler args must match
-    with pytest.raises(EventHandlerArgMismatch):
+    with pytest.raises(EventFnArgMismatch):
         component2.create(on_click=test_state.do_something_arg)
         component2.create(on_click=test_state.do_something_arg)
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(on_open=test_state.do_something)
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(on_prop_event=test_state.do_something)
 
 
     # Multiple EventHandler args: all must match
     # Multiple EventHandler args: all must match
-    with pytest.raises(EventHandlerArgMismatch):
+    with pytest.raises(EventFnArgMismatch):
         component2.create(
         component2.create(
             on_click=[test_state.do_something_arg, test_state.do_something]
             on_click=[test_state.do_something_arg, test_state.do_something]
         )
         )
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(
-            on_open=[test_state.do_something_arg, test_state.do_something]
-        )
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(
-            on_prop_event=[test_state.do_something_arg, test_state.do_something]
-        )
 
 
     # Enable when 0.7.0 happens
     # Enable when 0.7.0 happens
     # # Event Handler types must match
     # # Event Handler types must match
@@ -957,38 +944,19 @@ def test_invalid_event_handler_args(component2, test_state):
     # lambda signature must match event trigger.
     # lambda signature must match event trigger.
     with pytest.raises(EventFnArgMismatch):
     with pytest.raises(EventFnArgMismatch):
         component2.create(on_click=lambda _: test_state.do_something_arg(1))
         component2.create(on_click=lambda _: test_state.do_something_arg(1))
-    with pytest.raises(EventFnArgMismatch):
-        component2.create(on_open=lambda: test_state.do_something)
-    with pytest.raises(EventFnArgMismatch):
-        component2.create(on_prop_event=lambda: test_state.do_something)
 
 
     # lambda returning EventHandler must match spec
     # lambda returning EventHandler must match spec
-    with pytest.raises(EventHandlerArgMismatch):
+    with pytest.raises(EventFnArgMismatch):
         component2.create(on_click=lambda: test_state.do_something_arg)
         component2.create(on_click=lambda: test_state.do_something_arg)
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(on_open=lambda _: test_state.do_something)
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(on_prop_event=lambda _: test_state.do_something)
 
 
     # Mixed EventSpec and EventHandler must match spec.
     # Mixed EventSpec and EventHandler must match spec.
-    with pytest.raises(EventHandlerArgMismatch):
+    with pytest.raises(EventFnArgMismatch):
         component2.create(
         component2.create(
             on_click=lambda: [
             on_click=lambda: [
                 test_state.do_something_arg(1),
                 test_state.do_something_arg(1),
                 test_state.do_something_arg,
                 test_state.do_something_arg,
             ]
             ]
         )
         )
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(
-            on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
-        )
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(
-            on_prop_event=lambda _: [
-                test_state.do_something_arg(1),
-                test_state.do_something,
-            ]
-        )
 
 
 
 
 def test_valid_event_handler_args(component2, test_state):
 def test_valid_event_handler_args(component2, test_state):
@@ -1002,6 +970,10 @@ def test_valid_event_handler_args(component2, test_state):
     component2.create(on_click=test_state.do_something)
     component2.create(on_click=test_state.do_something)
     component2.create(on_click=test_state.do_something_arg(1))
     component2.create(on_click=test_state.do_something_arg(1))
 
 
+    # Does not raise because event handlers are allowed to have less args than the spec.
+    component2.create(on_open=test_state.do_something)
+    component2.create(on_prop_event=test_state.do_something)
+
     # Controlled event handlers should take args.
     # Controlled event handlers should take args.
     component2.create(on_open=test_state.do_something_arg)
     component2.create(on_open=test_state.do_something_arg)
     component2.create(on_prop_event=test_state.do_something_arg)
     component2.create(on_prop_event=test_state.do_something_arg)
@@ -1010,10 +982,20 @@ def test_valid_event_handler_args(component2, test_state):
     component2.create(on_open=test_state.do_something())
     component2.create(on_open=test_state.do_something())
     component2.create(on_prop_event=test_state.do_something())
     component2.create(on_prop_event=test_state.do_something())
 
 
+    # Multiple EventHandler args: all must match
+    component2.create(on_open=[test_state.do_something_arg, test_state.do_something])
+    component2.create(
+        on_prop_event=[test_state.do_something_arg, test_state.do_something]
+    )
+
     # lambda returning EventHandler is okay if the spec matches.
     # lambda returning EventHandler is okay if the spec matches.
     component2.create(on_click=lambda: test_state.do_something)
     component2.create(on_click=lambda: test_state.do_something)
     component2.create(on_open=lambda _: test_state.do_something_arg)
     component2.create(on_open=lambda _: test_state.do_something_arg)
     component2.create(on_prop_event=lambda _: test_state.do_something_arg)
     component2.create(on_prop_event=lambda _: test_state.do_something_arg)
+    component2.create(on_open=lambda: test_state.do_something)
+    component2.create(on_prop_event=lambda: test_state.do_something)
+    component2.create(on_open=lambda _: test_state.do_something)
+    component2.create(on_prop_event=lambda _: test_state.do_something)
 
 
     # lambda can always return an EventSpec.
     # lambda can always return an EventSpec.
     component2.create(on_click=lambda: test_state.do_something_arg(1))
     component2.create(on_click=lambda: test_state.do_something_arg(1))
@@ -1046,6 +1028,15 @@ def test_valid_event_handler_args(component2, test_state):
     component2.create(
     component2.create(
         on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()]
         on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()]
     )
     )
+    component2.create(
+        on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
+    )
+    component2.create(
+        on_prop_event=lambda _: [
+            test_state.do_something_arg(1),
+            test_state.do_something,
+        ]
+    )
 
 
 
 
 def test_get_hooks_nested(component1, component2, component3):
 def test_get_hooks_nested(component1, component2, component3):

+ 6 - 3
tests/units/test_event.py

@@ -107,7 +107,7 @@ def test_call_event_handler_partial():
     def spec(a2: Var[str]) -> List[Var[str]]:
     def spec(a2: Var[str]) -> List[Var[str]]:
         return [a2]
         return [a2]
 
 
-    handler = EventHandler(fn=test_fn_with_args)
+    handler = EventHandler(fn=test_fn_with_args, state_full_name="BigState")
     event_spec = handler(make_var("first"))
     event_spec = handler(make_var("first"))
     event_spec2 = call_event_handler(event_spec, spec)
     event_spec2 = call_event_handler(event_spec, spec)
 
 
@@ -115,7 +115,10 @@ def test_call_event_handler_partial():
     assert len(event_spec.args) == 1
     assert len(event_spec.args) == 1
     assert event_spec.args[0][0].equals(Var(_js_expr="arg1"))
     assert event_spec.args[0][0].equals(Var(_js_expr="arg1"))
     assert event_spec.args[0][1].equals(Var(_js_expr="first"))
     assert event_spec.args[0][1].equals(Var(_js_expr="first"))
-    assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})'
+    assert (
+        format.format_event(event_spec)
+        == 'Event("BigState.test_fn_with_args", {arg1:first})'
+    )
 
 
     assert event_spec2 is not event_spec
     assert event_spec2 is not event_spec
     assert event_spec2.handler == handler
     assert event_spec2.handler == handler
@@ -126,7 +129,7 @@ def test_call_event_handler_partial():
     assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str))
     assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str))
     assert (
     assert (
         format.format_event(event_spec2)
         format.format_event(event_spec2)
-        == 'Event("test_fn_with_args", {arg1:first,arg2:_a2})'
+        == 'Event("BigState.test_fn_with_args", {arg1:first,arg2:_a2})'
     )
     )