1
0
Эх сурвалжийг харах

add type checking for partially filled events

Khaleel Al-Adhami 6 өдөр өмнө
parent
commit
63c7888796

+ 152 - 93
reflex/event.py

@@ -1262,6 +1262,118 @@ def get_hydrate_event(state: BaseState) -> str:
     return get_event(state, constants.CompileVars.HYDRATE)
     return get_event(state, constants.CompileVars.HYDRATE)
 
 
 
 
+def _values_returned_from_event(
+    event_spec: ArgsSpec | Sequence[ArgsSpec],
+) -> list[Any]:
+    return [
+        event_spec_return_type
+        for arg_spec in (
+            [event_spec] if not isinstance(event_spec, Sequence) else list(event_spec)
+        )
+        if (event_spec_return_type := get_type_hints(arg_spec).get("return", None))
+        is not None
+        and get_origin(event_spec_return_type) is tuple
+    ]
+
+
+def _check_event_args_subclass_of_callback(
+    callback_params_names: list[str],
+    provided_event_types: list[Any],
+    callback_param_name_to_type: dict[str, Any],
+    callback_name: str = "",
+    key: str = "",
+):
+    """Check if the event handler arguments are subclass of the callback.
+
+    Args:
+        callback_params_names: The names of the callback parameters.
+        provided_event_types: The event types.
+        callback_param_name_to_type: The callback parameter name to type mapping.
+        callback_name: The name of the callback.
+        key: The key.
+
+    Raises:
+        TypeError: If the event handler arguments are invalid.
+        EventHandlerArgTypeMismatchError: If the event handler arguments do not match the callback.
+
+    # noqa: DAR401 delayed_exceptions[]
+    # noqa: DAR402 EventHandlerArgTypeMismatchError
+    """
+    type_match_found: dict[str, bool] = {}
+    delayed_exceptions: list[EventHandlerArgTypeMismatchError] = []
+
+    for event_spec_index, event_spec_return_type in enumerate(provided_event_types):
+        args = get_args(event_spec_return_type)
+
+        args_types_without_vars = [
+            arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
+        ]
+
+        # check that args of event handler are matching the spec if type hints are provided
+        for i, arg in enumerate(callback_params_names[: len(args_types_without_vars)]):
+            if arg not in callback_param_name_to_type:
+                continue
+
+            type_match_found.setdefault(arg, False)
+
+            try:
+                compare_result = typehint_issubclass(
+                    args_types_without_vars[i], callback_param_name_to_type[arg]
+                )
+            except TypeError as te:
+                callback_name_context = f" of {callback_name}" if callback_name else ""
+                key_context = f" for {key}" if key else ""
+                raise TypeError(
+                    f"Could not compare types {args_types_without_vars[i]} and {callback_param_name_to_type[arg]} for argument {arg}{callback_name_context}{key_context}."
+                ) from te
+
+            if compare_result:
+                type_match_found[arg] = True
+                continue
+            else:
+                type_match_found[arg] = False
+                as_annotated_in = (
+                    f" as annotated in {callback_name}" if callback_name else ""
+                )
+                delayed_exceptions.append(
+                    EventHandlerArgTypeMismatchError(
+                        f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {callback_param_name_to_type[arg]}{as_annotated_in} instead."
+                    )
+                )
+
+        if all(type_match_found.values()):
+            delayed_exceptions.clear()
+            if event_spec_index:
+                args = get_args(provided_event_types[0])
+
+                args_types_without_vars = [
+                    arg if get_origin(arg) is not Var else get_args(arg)[0]
+                    for arg in args
+                ]
+
+                expect_string = ", ".join(
+                    repr(arg) for arg in args_types_without_vars
+                ).replace("[", "\\[")
+
+                given_string = ", ".join(
+                    repr(callback_param_name_to_type.get(arg, Any))
+                    for arg in callback_params_names
+                ).replace("[", "\\[")
+
+                as_annotated_in = (
+                    f" as annotated in {callback_name}" if callback_name else ""
+                )
+
+                console.warn(
+                    f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> (){as_annotated_in} instead. "
+                    f"This may lead to unexpected behavior but is intentionally ignored for {key}."
+                )
+            break
+
+    if delayed_exceptions:
+        raise delayed_exceptions[0]
+
+
 def call_event_handler(
 def call_event_handler(
     event_callback: EventHandler | EventSpec,
     event_callback: EventHandler | EventSpec,
     event_spec: ArgsSpec | Sequence[ArgsSpec],
     event_spec: ArgsSpec | Sequence[ArgsSpec],
@@ -1278,17 +1390,13 @@ def call_event_handler(
         event_spec: The lambda that define the argument(s) to pass to the event handler.
         event_spec: The lambda that define the argument(s) to pass to the event handler.
         key: The key to pass to the event handler.
         key: The key to pass to the event handler.
 
 
-    Raises:
-        EventHandlerArgTypeMismatchError: If the event handler arguments do not match the event spec. #noqa: DAR402
-        TypeError: If the event handler arguments are invalid.
-
     Returns:
     Returns:
         The event spec from calling the event handler.
         The event spec from calling the event handler.
-
-    #noqa: DAR401
     """
     """
     event_spec_args = parse_args_spec(event_spec)
     event_spec_args = parse_args_spec(event_spec)
 
 
+    event_spec_return_types = _values_returned_from_event(event_spec)
+
     if isinstance(event_callback, EventSpec):
     if isinstance(event_callback, EventSpec):
         check_fn_match_arg_spec(
         check_fn_match_arg_spec(
             event_callback.handler.fn,
             event_callback.handler.fn,
@@ -1297,6 +1405,32 @@ def call_event_handler(
             bool(event_callback.handler.state_full_name) + len(event_callback.args),
             bool(event_callback.handler.state_full_name) + len(event_callback.args),
             event_callback.handler.fn.__qualname__,
             event_callback.handler.fn.__qualname__,
         )
         )
+
+        event_callback_spec_args = list(
+            inspect.signature(event_callback.handler.fn).parameters.keys()
+        )
+
+        try:
+            type_hints_of_provided_callback = get_type_hints(event_callback.handler.fn)
+        except NameError:
+            type_hints_of_provided_callback = {}
+
+        argument_names = [str(arg) for arg, value in event_callback.args]
+
+        _check_event_args_subclass_of_callback(
+            [
+                arg
+                for arg in event_callback_spec_args[
+                    bool(event_callback.handler.state_full_name) :
+                ]
+                if arg not in argument_names
+            ],
+            event_spec_return_types,
+            type_hints_of_provided_callback,
+            event_callback.handler.fn.__qualname__,
+            key or "",
+        )
+
         # Handle partial application of EventSpec args
         # Handle partial application of EventSpec args
         return event_callback.add_args(*event_spec_args)
         return event_callback.add_args(*event_spec_args)
 
 
@@ -1308,98 +1442,23 @@ def call_event_handler(
         event_callback.fn.__qualname__,
         event_callback.fn.__qualname__,
     )
     )
 
 
-    all_acceptable_specs = (
-        [event_spec] if not isinstance(event_spec, Sequence) else event_spec
-    )
-
-    event_spec_return_types = list(
-        filter(
-            lambda event_spec_return_type: event_spec_return_type is not None
-            and get_origin(event_spec_return_type) is tuple,
-            (
-                get_type_hints(arg_spec).get("return", None)
-                for arg_spec in all_acceptable_specs
-            ),
-        )
-    )
-    type_match_found: dict[str, bool] = {}
-    delayed_exceptions: list[EventHandlerArgTypeMismatchError] = []
-
-    try:
-        type_hints_of_provided_callback = get_type_hints(event_callback.fn)
-    except NameError:
-        type_hints_of_provided_callback = {}
-
     if event_spec_return_types:
     if event_spec_return_types:
         event_callback_spec_args = list(
         event_callback_spec_args = list(
             inspect.signature(event_callback.fn).parameters.keys()
             inspect.signature(event_callback.fn).parameters.keys()
         )
         )
 
 
-        for event_spec_index, event_spec_return_type in enumerate(
-            event_spec_return_types
-        ):
-            args = get_args(event_spec_return_type)
-
-            args_types_without_vars = [
-                arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
-            ]
-
-            # check that args of event handler are matching the spec if type hints are provided
-            for i, arg in enumerate(
-                event_callback_spec_args[1 : len(args_types_without_vars) + 1]
-            ):
-                if arg not in type_hints_of_provided_callback:
-                    continue
-
-                type_match_found.setdefault(arg, False)
-
-                try:
-                    compare_result = typehint_issubclass(
-                        args_types_without_vars[i], type_hints_of_provided_callback[arg]
-                    )
-                except TypeError as te:
-                    raise TypeError(
-                        f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
-                    ) from te
-
-                if compare_result:
-                    type_match_found[arg] = True
-                    continue
-                else:
-                    type_match_found[arg] = False
-                    delayed_exceptions.append(
-                        EventHandlerArgTypeMismatchError(
-                            f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
-                        )
-                    )
-
-            if all(type_match_found.values()):
-                delayed_exceptions.clear()
-                if event_spec_index:
-                    args = get_args(event_spec_return_types[0])
-
-                    args_types_without_vars = [
-                        arg if get_origin(arg) is not Var else get_args(arg)[0]
-                        for arg in args
-                    ]
-
-                    expect_string = ", ".join(
-                        repr(arg) for arg in args_types_without_vars
-                    ).replace("[", "\\[")
-
-                    given_string = ", ".join(
-                        repr(type_hints_of_provided_callback.get(arg, Any))
-                        for arg in event_callback_spec_args[1:]
-                    ).replace("[", "\\[")
-
-                    console.warn(
-                        f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
-                        f"This may lead to unexpected behavior but is intentionally ignored for {key}."
-                    )
-                break
-
-    if delayed_exceptions:
-        raise delayed_exceptions[0]
+        try:
+            type_hints_of_provided_callback = get_type_hints(event_callback.fn)
+        except NameError:
+            type_hints_of_provided_callback = {}
+
+        _check_event_args_subclass_of_callback(
+            event_callback_spec_args[1:],
+            event_spec_return_types,
+            type_hints_of_provided_callback,
+            event_callback.fn.__qualname__,
+            key or "",
+        )
 
 
     return event_callback(*event_spec_args)
     return event_callback(*event_spec_args)
 
 

+ 16 - 0
tests/units/components/test_component.py

@@ -945,6 +945,22 @@ def test_invalid_event_handler_args(component2, test_state):
         component2.create(on_user_list_changed=test_state.do_something_with_int)
         component2.create(on_user_list_changed=test_state.do_something_with_int)
     with pytest.raises(EventHandlerArgTypeMismatchError):
     with pytest.raises(EventHandlerArgTypeMismatchError):
         component2.create(on_user_list_changed=test_state.do_something_with_list_int)
         component2.create(on_user_list_changed=test_state.do_something_with_list_int)
+    with pytest.raises(EventHandlerArgTypeMismatchError):
+        component2.create(
+            on_user_visited_count_changed=test_state.do_something_with_bool()
+        )
+    with pytest.raises(EventHandlerArgTypeMismatchError):
+        component2.create(on_user_list_changed=test_state.do_something_with_int())
+    with pytest.raises(EventHandlerArgTypeMismatchError):
+        component2.create(on_user_list_changed=test_state.do_something_with_list_int())
+
+    component2.create(
+        on_user_visited_count_changed=test_state.do_something_with_bool(False)
+    )
+    component2.create(on_user_list_changed=test_state.do_something_with_int(23))
+    component2.create(
+        on_user_list_changed=test_state.do_something_with_list_int([2321, 321])
+    )
 
 
     component2.create(on_open=test_state.do_something_with_int)
     component2.create(on_open=test_state.do_something_with_int)
     component2.create(on_open=test_state.do_something_with_bool)
     component2.create(on_open=test_state.do_something_with_bool)