Преглед на файлове

allow for event handlers to ignore args (#4282)

* allow for event handlers to ignore args

* use a constant

* dang it darglint

* forgor

* keep the tests but move them to valid place
Khaleel Al-Adhami преди 6 месеца
родител
ревизия
6334cfab0d

+ 12 - 5
reflex/components/component.py

@@ -17,6 +17,7 @@ from typing import (
     Iterator,
     List,
     Optional,
+    Sequence,
     Set,
     Type,
     Union,
@@ -38,6 +39,7 @@ from reflex.constants import (
     PageNames,
 )
 from reflex.constants.compiler import SpecialAttributes
+from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.event import (
     EventCallback,
     EventChain,
@@ -533,7 +535,7 @@ class Component(BaseComponent, ABC):
 
     def _create_event_chain(
         self,
-        args_spec: Any,
+        args_spec: types.ArgsSpec | Sequence[types.ArgsSpec],
         value: Union[
             Var,
             EventHandler,
@@ -599,7 +601,7 @@ class Component(BaseComponent, ABC):
 
         # If the input is a callable, create an event chain.
         elif isinstance(value, Callable):
-            result = call_event_fn(value, args_spec)
+            result = call_event_fn(value, args_spec, key=key)
             if isinstance(result, Var):
                 # Recursively call this function if the lambda returned an EventChain Var.
                 return self._create_event_chain(args_spec, result, key=key)
@@ -629,14 +631,16 @@ class Component(BaseComponent, ABC):
                 event_actions={},
             )
 
-    def get_event_triggers(self) -> Dict[str, Any]:
+    def get_event_triggers(
+        self,
+    ) -> Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]]:
         """Get the event triggers for the component.
 
         Returns:
             The event triggers.
 
         """
-        default_triggers = {
+        default_triggers: Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
             EventTriggers.ON_FOCUS: no_args_event_spec,
             EventTriggers.ON_BLUR: no_args_event_spec,
             EventTriggers.ON_CLICK: no_args_event_spec,
@@ -1142,7 +1146,10 @@ class Component(BaseComponent, ABC):
                     if isinstance(event, EventCallback):
                         continue
                     if isinstance(event, EventSpec):
-                        if event.handler.state_full_name:
+                        if (
+                            event.handler.state_full_name
+                            and event.handler.state_full_name != FRONTEND_EVENT_STATE
+                        ):
                             return True
                     else:
                         if event._var_state:

+ 4 - 0
reflex/constants/state.py

@@ -9,3 +9,7 @@ class StateManagerMode(str, Enum):
     DISK = "disk"
     MEMORY = "memory"
     REDIS = "redis"
+
+
+# Used for things like console_log, etc.
+FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state"

+ 73 - 64
reflex/event.py

@@ -28,10 +28,10 @@ from typing import (
 from typing_extensions import ParamSpec, Protocol, get_args, get_origin
 
 from reflex import constants
+from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.utils import console, format
 from reflex.utils.exceptions import (
     EventFnArgMismatch,
-    EventHandlerArgMismatch,
     EventHandlerArgTypeMismatch,
 )
 from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
@@ -662,7 +662,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
     fn.__qualname__ = name
     fn.__signature__ = sig
     return EventSpec(
-        handler=EventHandler(fn=fn),
+        handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
         args=tuple(
             (
                 Var(_js_expr=k),
@@ -1092,8 +1092,8 @@ def get_hydrate_event(state) -> str:
 
 
 def call_event_handler(
-    event_handler: EventHandler | EventSpec,
-    arg_spec: ArgsSpec | Sequence[ArgsSpec],
+    event_callback: EventHandler | EventSpec,
+    event_spec: ArgsSpec | Sequence[ArgsSpec],
     key: Optional[str] = None,
 ) -> EventSpec:
     """Call an event handler to get the event spec.
@@ -1103,53 +1103,57 @@ def call_event_handler(
     Otherwise, the event handler will be called with no args.
 
     Args:
-        event_handler: The event handler.
-        arg_spec: The lambda that define the argument(s) to pass to the event handler.
+        event_callback: The event handler.
+        event_spec: The lambda that define the argument(s) to pass to the event handler.
         key: The key to pass to the event handler.
 
-    Raises:
-        EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
-
     Returns:
         The event spec from calling the event handler.
 
     # noqa: DAR401 failure
 
     """
-    parsed_args = parse_args_spec(arg_spec)  # type: ignore
-
-    if isinstance(event_handler, EventSpec):
-        # Handle partial application of EventSpec args
-        return event_handler.add_args(*parsed_args)
-
-    provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
-
-    provided_callback_n_args = (
-        len(provided_callback_fullspec.args) - 1
-    )  # subtract 1 for bound self arg
-
-    if provided_callback_n_args != len(parsed_args):
-        raise EventHandlerArgMismatch(
-            "The number of arguments accepted by "
-            f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
-            "does not match the arguments passed by the event trigger: "
-            f"{[str(v) for v in parsed_args]}\n"
-            "See https://reflex.dev/docs/events/event-arguments/"
+    event_spec_args = parse_args_spec(event_spec)  # type: ignore
+
+    if isinstance(event_callback, EventSpec):
+        check_fn_match_arg_spec(
+            event_callback.handler.fn,
+            event_spec,
+            key,
+            bool(event_callback.handler.state_full_name) + len(event_callback.args),
+            event_callback.handler.fn.__qualname__,
         )
+        # Handle partial application of EventSpec args
+        return event_callback.add_args(*event_spec_args)
+
+    check_fn_match_arg_spec(
+        event_callback.fn,
+        event_spec,
+        key,
+        bool(event_callback.state_full_name),
+        event_callback.fn.__qualname__,
+    )
 
-    all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
+    all_acceptable_specs = (
+        [event_spec] if not isinstance(event_spec, Sequence) else event_spec
+    )
 
     event_spec_return_types = list(
         filter(
             lambda event_spec_return_type: event_spec_return_type is not None
             and get_origin(event_spec_return_type) is tuple,
-            (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
+            (
+                get_type_hints(arg_spec).get("return", None)
+                for arg_spec in all_acceptable_specs
+            ),
         )
     )
 
     if event_spec_return_types:
         failures = []
 
+        event_callback_spec = inspect.getfullargspec(event_callback.fn)
+
         for event_spec_index, event_spec_return_type in enumerate(
             event_spec_return_types
         ):
@@ -1160,14 +1164,14 @@ def call_event_handler(
             ]
 
             try:
-                type_hints_of_provided_callback = get_type_hints(event_handler.fn)
+                type_hints_of_provided_callback = get_type_hints(event_callback.fn)
             except NameError:
                 type_hints_of_provided_callback = {}
 
             failed_type_check = False
 
             # check that args of event handler are matching the spec if type hints are provided
-            for i, arg in enumerate(provided_callback_fullspec.args[1:]):
+            for i, arg in enumerate(event_callback_spec.args[1:]):
                 if arg not in type_hints_of_provided_callback:
                     continue
 
@@ -1181,7 +1185,7 @@ def call_event_handler(
                     #     f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
                     # ) from e
                     console.warn(
-                        f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
+                        f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
                     )
                     compare_result = False
 
@@ -1189,7 +1193,7 @@ def call_event_handler(
                     continue
                 else:
                     failure = EventHandlerArgTypeMismatch(
-                        f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
+                        f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
                     )
                     failures.append(failure)
                     failed_type_check = True
@@ -1210,14 +1214,14 @@ def call_event_handler(
 
                     given_string = ", ".join(
                         repr(type_hints_of_provided_callback.get(arg, Any))
-                        for arg in provided_callback_fullspec.args[1:]
+                        for arg in event_callback_spec.args[1:]
                     ).replace("[", "\\[")
 
                     console.warn(
-                        f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
+                        f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
                         f"This may lead to unexpected behavior but is intentionally ignored for {key}."
                     )
-                return event_handler(*parsed_args)
+                return event_callback(*event_spec_args)
 
         if failures:
             console.deprecate(
@@ -1227,7 +1231,7 @@ def call_event_handler(
                 "0.7.0",
             )
 
-    return event_handler(*parsed_args)  # type: ignore
+    return event_callback(*event_spec_args)  # type: ignore
 
 
 def unwrap_var_annotation(annotation: GenericType):
@@ -1294,45 +1298,46 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
 
 
 def check_fn_match_arg_spec(
-    fn: Callable,
-    arg_spec: ArgsSpec,
-    key: Optional[str] = None,
-) -> List[Var]:
+    user_func: Callable,
+    arg_spec: ArgsSpec | Sequence[ArgsSpec],
+    key: str | None = None,
+    number_of_bound_args: int = 0,
+    func_name: str | None = None,
+):
     """Ensures that the function signature matches the passed argument specification
     or raises an EventFnArgMismatch if they do not.
 
     Args:
-        fn: The function to be validated.
+        user_func: The function to be validated.
         arg_spec: The argument specification for the event trigger.
-        key: The key to pass to the event handler.
-
-    Returns:
-        The parsed arguments from the argument specification.
+        key: The key of the event trigger.
+        number_of_bound_args: The number of bound arguments to the function.
+        func_name: The name of the function to be validated.
 
     Raises:
         EventFnArgMismatch: Raised if the number of mandatory arguments do not match
     """
-    fn_args = inspect.getfullargspec(fn).args
-    fn_defaults_args = inspect.getfullargspec(fn).defaults
-    n_fn_args = len(fn_args)
-    n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
-    if isinstance(fn, types.MethodType):
-        n_fn_args -= 1  # subtract 1 for bound self arg
-    parsed_args = parse_args_spec(arg_spec)
-    if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
+    user_args = inspect.getfullargspec(user_func).args
+    user_default_args = inspect.getfullargspec(user_func).defaults
+    number_of_user_args = len(user_args) - number_of_bound_args
+    number_of_user_default_args = len(user_default_args) if user_default_args else 0
+
+    parsed_event_args = parse_args_spec(arg_spec)
+
+    number_of_event_args = len(parsed_event_args)
+
+    if number_of_user_args - number_of_user_default_args > number_of_event_args:
         raise EventFnArgMismatch(
-            "The number of mandatory arguments accepted by "
-            f"{fn} ({n_fn_args - n_fn_defaults_args}) "
-            "does not match the arguments passed by the event trigger: "
-            f"{[str(v) for v in parsed_args]}\n"
+            f"Event {key} only provides {number_of_event_args} arguments, but "
+            f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
+            "arguments to be passed to the event handler.\n"
             "See https://reflex.dev/docs/events/event-arguments/"
         )
-    return parsed_args
 
 
 def call_event_fn(
     fn: Callable,
-    arg_spec: ArgsSpec,
+    arg_spec: ArgsSpec | Sequence[ArgsSpec],
     key: Optional[str] = None,
 ) -> list[EventSpec] | Var:
     """Call a function to a list of event specs.
@@ -1356,10 +1361,14 @@ def call_event_fn(
     from reflex.utils.exceptions import EventHandlerValueError
 
     # Check that fn signature matches arg_spec
-    parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
+    check_fn_match_arg_spec(fn, arg_spec, key=key)
+
+    parsed_args = parse_args_spec(arg_spec)
+
+    number_of_fn_args = len(inspect.getfullargspec(fn).args)
 
     # Call the function with the parsed args.
-    out = fn(*parsed_args)
+    out = fn(*[*parsed_args][:number_of_fn_args])
 
     # If the function returns a Var, assume it's an EventChain and render it directly.
     if isinstance(out, Var):
@@ -1478,7 +1487,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
     """
     signature = inspect.signature(fn)
     new_param = inspect.Parameter(
-        "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
+        FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
     )
     return signature.replace(parameters=(new_param, *signature.parameters.values()))
 

+ 1 - 5
reflex/utils/exceptions.py

@@ -89,16 +89,12 @@ class MatchTypeError(ReflexError, TypeError):
     """Raised when the return types of match cases are different."""
 
 
-class EventHandlerArgMismatch(ReflexError, TypeError):
-    """Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
-
-
 class EventHandlerArgTypeMismatch(ReflexError, TypeError):
     """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
 
 
 class EventFnArgMismatch(ReflexError, TypeError):
-    """Raised when the number of args accepted by a lambda differs from that provided by the event trigger."""
+    """Raised when the number of args required by an event handler is more than provided by the event trigger."""
 
 
 class DynamicRouteArgShadowsStateVar(ReflexError, NameError):

+ 2 - 1
reflex/utils/format.py

@@ -9,6 +9,7 @@ import re
 from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 from reflex import constants
+from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.utils import exceptions
 from reflex.utils.console import deprecate
 
@@ -439,7 +440,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
 
     from reflex.state import State
 
-    if state_full_name == "state" and name not in State.__dict__:
+    if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__:
         return ("", to_snake_case(handler.fn.__qualname__))
 
     return (state_full_name, name)

+ 2 - 2
reflex/utils/pyi_generator.py

@@ -16,7 +16,7 @@ from itertools import chain
 from multiprocessing import Pool, cpu_count
 from pathlib import Path
 from types import ModuleType, SimpleNamespace
-from typing import Any, Callable, Iterable, Type, get_args, get_origin
+from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
 
 from reflex.components.component import Component
 from reflex.utils import types as rx_types
@@ -560,7 +560,7 @@ def _generate_component_create_functiondef(
                                     inspect.signature(event_specs).return_annotation
                                 )
                                 if not isinstance(
-                                    event_specs := event_triggers[trigger], tuple
+                                    event_specs := event_triggers[trigger], Sequence
                                 )
                                 else ast.Subscript(
                                     ast.Name("Union"),

+ 27 - 36
tests/units/components/test_component.py

@@ -29,7 +29,6 @@ from reflex.style import Style
 from reflex.utils import imports
 from reflex.utils.exceptions import (
     EventFnArgMismatch,
-    EventHandlerArgMismatch,
 )
 from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.vars import VarData
@@ -907,26 +906,14 @@ def test_invalid_event_handler_args(component2, test_state):
         test_state: A test state.
     """
     # EventHandler args must match
-    with pytest.raises(EventHandlerArgMismatch):
+    with pytest.raises(EventFnArgMismatch):
         component2.create(on_click=test_state.do_something_arg)
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(on_open=test_state.do_something)
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(on_prop_event=test_state.do_something)
 
     # Multiple EventHandler args: all must match
-    with pytest.raises(EventHandlerArgMismatch):
+    with pytest.raises(EventFnArgMismatch):
         component2.create(
             on_click=[test_state.do_something_arg, test_state.do_something]
         )
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(
-            on_open=[test_state.do_something_arg, test_state.do_something]
-        )
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(
-            on_prop_event=[test_state.do_something_arg, test_state.do_something]
-        )
 
     # Enable when 0.7.0 happens
     # # Event Handler types must match
@@ -957,38 +944,19 @@ def test_invalid_event_handler_args(component2, test_state):
     # lambda signature must match event trigger.
     with pytest.raises(EventFnArgMismatch):
         component2.create(on_click=lambda _: test_state.do_something_arg(1))
-    with pytest.raises(EventFnArgMismatch):
-        component2.create(on_open=lambda: test_state.do_something)
-    with pytest.raises(EventFnArgMismatch):
-        component2.create(on_prop_event=lambda: test_state.do_something)
 
     # lambda returning EventHandler must match spec
-    with pytest.raises(EventHandlerArgMismatch):
+    with pytest.raises(EventFnArgMismatch):
         component2.create(on_click=lambda: test_state.do_something_arg)
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(on_open=lambda _: test_state.do_something)
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(on_prop_event=lambda _: test_state.do_something)
 
     # Mixed EventSpec and EventHandler must match spec.
-    with pytest.raises(EventHandlerArgMismatch):
+    with pytest.raises(EventFnArgMismatch):
         component2.create(
             on_click=lambda: [
                 test_state.do_something_arg(1),
                 test_state.do_something_arg,
             ]
         )
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(
-            on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
-        )
-    with pytest.raises(EventHandlerArgMismatch):
-        component2.create(
-            on_prop_event=lambda _: [
-                test_state.do_something_arg(1),
-                test_state.do_something,
-            ]
-        )
 
 
 def test_valid_event_handler_args(component2, test_state):
@@ -1002,6 +970,10 @@ def test_valid_event_handler_args(component2, test_state):
     component2.create(on_click=test_state.do_something)
     component2.create(on_click=test_state.do_something_arg(1))
 
+    # Does not raise because event handlers are allowed to have less args than the spec.
+    component2.create(on_open=test_state.do_something)
+    component2.create(on_prop_event=test_state.do_something)
+
     # Controlled event handlers should take args.
     component2.create(on_open=test_state.do_something_arg)
     component2.create(on_prop_event=test_state.do_something_arg)
@@ -1010,10 +982,20 @@ def test_valid_event_handler_args(component2, test_state):
     component2.create(on_open=test_state.do_something())
     component2.create(on_prop_event=test_state.do_something())
 
+    # Multiple EventHandler args: all must match
+    component2.create(on_open=[test_state.do_something_arg, test_state.do_something])
+    component2.create(
+        on_prop_event=[test_state.do_something_arg, test_state.do_something]
+    )
+
     # lambda returning EventHandler is okay if the spec matches.
     component2.create(on_click=lambda: test_state.do_something)
     component2.create(on_open=lambda _: test_state.do_something_arg)
     component2.create(on_prop_event=lambda _: test_state.do_something_arg)
+    component2.create(on_open=lambda: test_state.do_something)
+    component2.create(on_prop_event=lambda: test_state.do_something)
+    component2.create(on_open=lambda _: test_state.do_something)
+    component2.create(on_prop_event=lambda _: test_state.do_something)
 
     # lambda can always return an EventSpec.
     component2.create(on_click=lambda: test_state.do_something_arg(1))
@@ -1046,6 +1028,15 @@ def test_valid_event_handler_args(component2, test_state):
     component2.create(
         on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()]
     )
+    component2.create(
+        on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
+    )
+    component2.create(
+        on_prop_event=lambda _: [
+            test_state.do_something_arg(1),
+            test_state.do_something,
+        ]
+    )
 
 
 def test_get_hooks_nested(component1, component2, component3):

+ 6 - 3
tests/units/test_event.py

@@ -107,7 +107,7 @@ def test_call_event_handler_partial():
     def spec(a2: Var[str]) -> List[Var[str]]:
         return [a2]
 
-    handler = EventHandler(fn=test_fn_with_args)
+    handler = EventHandler(fn=test_fn_with_args, state_full_name="BigState")
     event_spec = handler(make_var("first"))
     event_spec2 = call_event_handler(event_spec, spec)
 
@@ -115,7 +115,10 @@ def test_call_event_handler_partial():
     assert len(event_spec.args) == 1
     assert event_spec.args[0][0].equals(Var(_js_expr="arg1"))
     assert event_spec.args[0][1].equals(Var(_js_expr="first"))
-    assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})'
+    assert (
+        format.format_event(event_spec)
+        == 'Event("BigState.test_fn_with_args", {arg1:first})'
+    )
 
     assert event_spec2 is not event_spec
     assert event_spec2.handler == handler
@@ -126,7 +129,7 @@ def test_call_event_handler_partial():
     assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str))
     assert (
         format.format_event(event_spec2)
-        == 'Event("test_fn_with_args", {arg1:first,arg2:_a2})'
+        == 'Event("BigState.test_fn_with_args", {arg1:first,arg2:_a2})'
     )