Переглянути джерело

[REF-3589] raise EventHandlerArgMismatch when event handler args don't match spec (#3853)

* test_component: improve valid/invalid event trigger tests

Add test cases for event triggers defined as annotations.

Add additional cases around lambda returning different values.

Improve assertions for invalid tests (each line needs its own `pytest.raises`).

More invalid test cases.

* [REF-3589] raise EventHandlerArgMismatch when event handler args don't match spec

Improve error message for common issue.

Previously when the event handler arguments didn't match the spec, the
traceback resulted in:

```
OSError: could not get source code
```

Now this problem is traceable as a distinct error condition and users are
empowered to debug their code and reference the documentation (to be updated)
for further information.

* raise EventFnArgMismatch when lambda args don't match event trigger spec

Improve error message for another common issue encountered in the reflex framework.

Previous error message was

```
TypeError: index.<locals>.<lambda>() takes 0 positional arguments but 1 was given
```

* Fix up lambda test cases

* call_event_fn: adjust number of args for bound methods
Masen Furer 9 місяців тому
батько
коміт
356deb5457

+ 1 - 9
reflex/components/component.py

@@ -527,15 +527,7 @@ class Component(BaseComponent, ABC):
             for v in value:
                 if isinstance(v, (EventHandler, EventSpec)):
                     # Call the event handler to get the event.
-                    try:
-                        event = call_event_handler(v, args_spec)
-                    except ValueError as err:
-                        raise ValueError(
-                            f" {err} defined in the `{type(self).__name__}` component"
-                        ) from err
-
-                    # Add the event to the chain.
-                    events.append(event)
+                    events.append(call_event_handler(v, args_spec))
                 elif isinstance(v, Callable):
                     # Call the lambda to get the event chain.
                     result = call_event_fn(v, args_spec)

+ 36 - 29
reflex/event.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import inspect
+import types
 import urllib.parse
 from base64 import b64encode
 from typing import (
@@ -22,6 +23,7 @@ from reflex.ivars.base import ImmutableVar, LiteralVar
 from reflex.ivars.function import FunctionStringVar, FunctionVar
 from reflex.ivars.object import ObjectVar
 from reflex.utils import format
+from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
 from reflex.utils.types import ArgsSpec
 from reflex.vars import ImmutableVarData, Var
 
@@ -831,7 +833,7 @@ def call_event_handler(
         arg_spec: The lambda that define the argument(s) to pass to the event handler.
 
     Raises:
-        ValueError: if number of arguments expected by event_handler doesn't match the spec.
+        EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
 
     Returns:
         The event spec from calling the event handler.
@@ -843,13 +845,16 @@ def call_event_handler(
         return event_handler.add_args(*parsed_args)
 
     args = inspect.getfullargspec(event_handler.fn).args
-    if len(args) == len(["self", *parsed_args]):
+    n_args = len(args) - 1  # subtract 1 for bound self arg
+    if n_args == len(parsed_args):
         return event_handler(*parsed_args)  # type: ignore
     else:
-        source = inspect.getsource(arg_spec)  # type: ignore
-        raise ValueError(
-            f"number of arguments in {event_handler.fn.__qualname__} "
-            f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
+        raise EventHandlerArgMismatch(
+            "The number of arguments accepted by "
+            f"{event_handler.fn.__qualname__} ({n_args}) "
+            "does not match the arguments passed by the event trigger: "
+            f"{[str(v) for v in parsed_args]}\n"
+            "See https://reflex.dev/docs/events/event-arguments/"
         )
 
 
@@ -874,58 +879,60 @@ def parse_args_spec(arg_spec: ArgsSpec):
     )
 
 
-def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec] | Var:
+def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     """Call a function to a list of event specs.
 
     The function should return a single EventSpec, a list of EventSpecs, or a
-    single Var. If the function takes in an arg, the arg will be passed to the
-    function. Otherwise, the function will be called with no args.
+    single Var. The function signature must match the passed arg_spec or
+    EventFnArgsMismatch will be raised.
 
     Args:
         fn: The function to call.
-        arg: The argument to pass to the function.
+        arg_spec: The argument spec for the event trigger.
 
     Returns:
         The event specs from calling the function or a Var.
 
     Raises:
-        EventHandlerValueError: If the lambda has an invalid signature.
+        EventFnArgMismatch: If the function signature doesn't match the arg spec.
+        EventHandlerValueError: If the lambda returns an unusable value.
     """
     # Import here to avoid circular imports.
     from reflex.event import EventHandler, EventSpec
     from reflex.utils.exceptions import EventHandlerValueError
 
-    # Get the args of the lambda.
-    args = inspect.getfullargspec(fn).args
+    # Check that fn signature matches arg_spec
+    fn_args = inspect.getfullargspec(fn).args
+    n_fn_args = len(fn_args)
+    if isinstance(fn, types.MethodType):
+        n_fn_args -= 1  # subtract 1 for bound self arg
+    parsed_args = parse_args_spec(arg_spec)
+    if len(parsed_args) != n_fn_args:
+        raise EventFnArgMismatch(
+            "The number of arguments accepted by "
+            f"{fn} ({n_fn_args}) "
+            "does not match the arguments passed by the event trigger: "
+            f"{[str(v) for v in parsed_args]}\n"
+            "See https://reflex.dev/docs/events/event-arguments/"
+        )
 
-    if isinstance(arg, ArgsSpec):
-        out = fn(*parse_args_spec(arg))  # type: ignore
-    else:
-        # Call the lambda.
-        if len(args) == 0:
-            out = fn()
-        elif len(args) == 1:
-            out = fn(arg)
-        else:
-            raise EventHandlerValueError(f"Lambda {fn} must have 0 or 1 arguments.")
+    # Call the function with the parsed args.
+    out = fn(*parsed_args)
 
     # If the function returns a Var, assume it's an EventChain and render it directly.
     if isinstance(out, Var):
         return out
 
     # Convert the output to a list.
-    if not isinstance(out, List):
+    if not isinstance(out, list):
         out = [out]
 
     # Convert any event specs to event specs.
     events = []
     for e in out:
-        # Convert handlers to event specs.
         if isinstance(e, EventHandler):
-            if len(args) == 0:
-                e = e()
-            elif len(args) == 1:
-                e = e(arg)  # type: ignore
+            # An un-called EventHandler gets all of the args of the event trigger.
+            e = call_event_handler(e, arg_spec)
 
         # Make sure the event spec is valid.
         if not isinstance(e, EventSpec):

+ 8 - 0
reflex/utils/exceptions.py

@@ -79,3 +79,11 @@ class LockExpiredError(ReflexError):
 
 class MatchTypeError(ReflexError, TypeError):
     """Raised when the return types of match cases are different."""
+
+
+class EventHandlerArgMismatch(ReflexError, TypeError):
+    """Raised when the number of args accepted by an EventHandler is differs from that provided by the event trigger."""
+
+
+class EventFnArgMismatch(ReflexError, TypeError):
+    """Raised when the number of args accepted by a lambda differs from that provided by the event trigger."""

+ 116 - 11
tests/components/test_component.py

@@ -22,6 +22,7 @@ from reflex.ivars.base import LiteralVar
 from reflex.state import BaseState
 from reflex.style import Style
 from reflex.utils import imports
+from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
 from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.vars import BaseVar, Var, VarData
 
@@ -79,6 +80,8 @@ def component2() -> Type[Component]:
         # A test list prop.
         arr: Var[List[str]]
 
+        on_prop_event: EventHandler[lambda e0: [e0]]
+
         def get_event_triggers(self) -> Dict[str, Any]:
             """Test controlled triggers.
 
@@ -496,7 +499,7 @@ def test_get_props(component1, component2):
         component2: A test component.
     """
     assert component1.get_props() == {"text", "number", "text_or_number"}
-    assert component2.get_props() == {"arr"}
+    assert component2.get_props() == {"arr", "on_prop_event"}
 
 
 @pytest.mark.parametrize(
@@ -574,7 +577,7 @@ def test_get_event_triggers(component1, component2):
     assert component1().get_event_triggers().keys() == default_triggers
     assert (
         component2().get_event_triggers().keys()
-        == {"on_open", "on_close"} | default_triggers
+        == {"on_open", "on_close", "on_prop_event"} | default_triggers
     )
 
 
@@ -888,18 +891,105 @@ def test_invalid_event_handler_args(component2, test_state):
         component2: A test component.
         test_state: A test state.
     """
-    # Uncontrolled event handlers should not take args.
-    # This is okay.
-    component2.create(on_click=test_state.do_something)
-    # This is not okay.
-    with pytest.raises(ValueError):
+    # EventHandler args must match
+    with pytest.raises(EventHandlerArgMismatch):
         component2.create(on_click=test_state.do_something_arg)
+    with pytest.raises(EventHandlerArgMismatch):
         component2.create(on_open=test_state.do_something)
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(on_prop_event=test_state.do_something)
+
+    # Multiple EventHandler args: all must match
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(
+            on_click=[test_state.do_something_arg, test_state.do_something]
+        )
+    with pytest.raises(EventHandlerArgMismatch):
         component2.create(
             on_open=[test_state.do_something_arg, test_state.do_something]
         )
-    # However lambdas are okay.
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(
+            on_prop_event=[test_state.do_something_arg, test_state.do_something]
+        )
+
+    # lambda cannot return weird values.
+    with pytest.raises(ValueError):
+        component2.create(on_click=lambda: 1)
+    with pytest.raises(ValueError):
+        component2.create(on_click=lambda: [1])
+    with pytest.raises(ValueError):
+        component2.create(
+            on_click=lambda: (test_state.do_something_arg(1), test_state.do_something)
+        )
+
+    # lambda signature must match event trigger.
+    with pytest.raises(EventFnArgMismatch):
+        component2.create(on_click=lambda _: test_state.do_something_arg(1))
+    with pytest.raises(EventFnArgMismatch):
+        component2.create(on_open=lambda: test_state.do_something)
+    with pytest.raises(EventFnArgMismatch):
+        component2.create(on_prop_event=lambda: test_state.do_something)
+
+    # lambda returning EventHandler must match spec
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(on_click=lambda: test_state.do_something_arg)
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(on_open=lambda _: test_state.do_something)
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(on_prop_event=lambda _: test_state.do_something)
+
+    # Mixed EventSpec and EventHandler must match spec.
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(
+            on_click=lambda: [
+                test_state.do_something_arg(1),
+                test_state.do_something_arg,
+            ]
+        )
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(
+            on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
+        )
+    with pytest.raises(EventHandlerArgMismatch):
+        component2.create(
+            on_prop_event=lambda _: [
+                test_state.do_something_arg(1),
+                test_state.do_something,
+            ]
+        )
+
+
+def test_valid_event_handler_args(component2, test_state):
+    """Test that an valid event handler args do not raise exception.
+
+    Args:
+        component2: A test component.
+        test_state: A test state.
+    """
+    # Uncontrolled event handlers should not take args.
+    component2.create(on_click=test_state.do_something)
+    component2.create(on_click=test_state.do_something_arg(1))
+
+    # Controlled event handlers should take args.
+    component2.create(on_open=test_state.do_something_arg)
+    component2.create(on_prop_event=test_state.do_something_arg)
+
+    # Using a partial event spec bypasses arg validation (ignoring the args).
+    component2.create(on_open=test_state.do_something())
+    component2.create(on_prop_event=test_state.do_something())
+
+    # lambda returning EventHandler is okay if the spec matches.
+    component2.create(on_click=lambda: test_state.do_something)
+    component2.create(on_open=lambda _: test_state.do_something_arg)
+    component2.create(on_prop_event=lambda _: test_state.do_something_arg)
+
+    # lambda can always return an EventSpec.
     component2.create(on_click=lambda: test_state.do_something_arg(1))
+    component2.create(on_open=lambda _: test_state.do_something_arg(1))
+    component2.create(on_prop_event=lambda _: test_state.do_something_arg(1))
+
+    # Return EventSpec and EventHandler (no arg).
     component2.create(
         on_click=lambda: [test_state.do_something_arg(1), test_state.do_something]
     )
@@ -907,9 +997,24 @@ def test_invalid_event_handler_args(component2, test_state):
         on_click=lambda: [test_state.do_something_arg(1), test_state.do_something()]
     )
 
-    # Controlled event handlers should take args.
-    # This is okay.
-    component2.create(on_open=test_state.do_something_arg)
+    # Return 2 EventSpec.
+    component2.create(
+        on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something()]
+    )
+    component2.create(
+        on_prop_event=lambda _: [
+            test_state.do_something_arg(1),
+            test_state.do_something(),
+        ]
+    )
+
+    # Return EventHandler (1 arg) and EventSpec.
+    component2.create(
+        on_open=lambda _: [test_state.do_something_arg, test_state.do_something()]
+    )
+    component2.create(
+        on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()]
+    )
 
 
 def test_get_hooks_nested(component1, component2, component3):