Przeglądaj źródła

[ENG-3943]type check for event handler if spec arg are typed (#4046)

* type check for event handler if spec arg are typed

* fix the typecheck logic

* rearrange logic pieces

* add try except

* add try except around compare

* change form and improve type checking

* print key instead

* dang it darglint

* change wording

* add basic test to cover it

* add a slightly more complicated test

* challenge it a bit by doing small capital list

* add multiple argspec

* fix slider event order

* i hate 3.9

* add note for UnionType

* move function to types

* add a test for type hint is subclass

* make on submit dict str any

* add testing for dict cases

* add check against any

* accept dict str str

* bruh i used i twice

* escape strings and print actual error message

* disable the error and print deprecation warning instead

* disable tests

* fix doc message

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Thomas Brandého 6 miesięcy temu
rodzic
commit
c07eb2a6a0

+ 8 - 4
reflex/components/component.py

@@ -480,6 +480,7 @@ class Component(BaseComponent, ABC):
                 kwargs["event_triggers"][key] = self._create_event_chain(
                 kwargs["event_triggers"][key] = self._create_event_chain(
                     value=value,  # type: ignore
                     value=value,  # type: ignore
                     args_spec=component_specific_triggers[key],
                     args_spec=component_specific_triggers[key],
+                    key=key,
                 )
                 )
 
 
         # Remove any keys that were added as events.
         # Remove any keys that were added as events.
@@ -540,12 +541,14 @@ class Component(BaseComponent, ABC):
             List[Union[EventHandler, EventSpec, EventVar]],
             List[Union[EventHandler, EventSpec, EventVar]],
             Callable,
             Callable,
         ],
         ],
+        key: Optional[str] = None,
     ) -> Union[EventChain, Var]:
     ) -> Union[EventChain, Var]:
         """Create an event chain from a variety of input types.
         """Create an event chain from a variety of input types.
 
 
         Args:
         Args:
             args_spec: The args_spec of the event trigger being bound.
             args_spec: The args_spec of the event trigger being bound.
             value: The value to create the event chain from.
             value: The value to create the event chain from.
+            key: The key of the event trigger being bound.
 
 
         Returns:
         Returns:
             The event chain.
             The event chain.
@@ -560,7 +563,7 @@ class Component(BaseComponent, ABC):
             elif isinstance(value, EventVar):
             elif isinstance(value, EventVar):
                 value = [value]
                 value = [value]
             elif issubclass(value._var_type, (EventChain, EventSpec)):
             elif issubclass(value._var_type, (EventChain, EventSpec)):
-                return self._create_event_chain(args_spec, value.guess_type())
+                return self._create_event_chain(args_spec, value.guess_type(), key=key)
             else:
             else:
                 raise ValueError(
                 raise ValueError(
                     f"Invalid event chain: {str(value)} of type {value._var_type}"
                     f"Invalid event chain: {str(value)} of type {value._var_type}"
@@ -579,10 +582,10 @@ class Component(BaseComponent, ABC):
             for v in value:
             for v in value:
                 if isinstance(v, (EventHandler, EventSpec)):
                 if isinstance(v, (EventHandler, EventSpec)):
                     # Call the event handler to get the event.
                     # Call the event handler to get the event.
-                    events.append(call_event_handler(v, args_spec))
+                    events.append(call_event_handler(v, args_spec, key=key))
                 elif isinstance(v, Callable):
                 elif isinstance(v, Callable):
                     # Call the lambda to get the event chain.
                     # Call the lambda to get the event chain.
-                    result = call_event_fn(v, args_spec)
+                    result = call_event_fn(v, args_spec, key=key)
                     if isinstance(result, Var):
                     if isinstance(result, Var):
                         raise ValueError(
                         raise ValueError(
                             f"Invalid event chain: {v}. Cannot use a Var-returning "
                             f"Invalid event chain: {v}. Cannot use a Var-returning "
@@ -599,7 +602,7 @@ class Component(BaseComponent, ABC):
             result = call_event_fn(value, args_spec)
             result = call_event_fn(value, args_spec)
             if isinstance(result, Var):
             if isinstance(result, Var):
                 # Recursively call this function if the lambda returned an EventChain Var.
                 # Recursively call this function if the lambda returned an EventChain Var.
-                return self._create_event_chain(args_spec, result)
+                return self._create_event_chain(args_spec, result, key=key)
             events = [*result]
             events = [*result]
 
 
         # Otherwise, raise an error.
         # Otherwise, raise an error.
@@ -1722,6 +1725,7 @@ class CustomComponent(Component):
                     args_spec=event_triggers_in_component_declaration.get(
                     args_spec=event_triggers_in_component_declaration.get(
                         key, empty_event
                         key, empty_event
                     ),
                     ),
+                    key=key,
                 )
                 )
                 self.props[format.to_camel_case(key)] = value
                 self.props[format.to_camel_case(key)] = value
                 continue
                 continue

+ 10 - 1
reflex/components/el/elements/forms.py

@@ -111,6 +111,15 @@ def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]:
     return (FORM_DATA,)
     return (FORM_DATA,)
 
 
 
 
+def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]:
+    """Event handler spec for the on_submit event.
+
+    Returns:
+        The event handler spec.
+    """
+    return (FORM_DATA,)
+
+
 class Form(BaseHTML):
 class Form(BaseHTML):
     """Display the form element."""
     """Display the form element."""
 
 
@@ -150,7 +159,7 @@ class Form(BaseHTML):
     handle_submit_unique_name: Var[str]
     handle_submit_unique_name: Var[str]
 
 
     # Fired when the form is submitted
     # Fired when the form is submitted
-    on_submit: EventHandler[on_submit_event_spec]
+    on_submit: EventHandler[on_submit_event_spec, on_submit_string_event_spec]
 
 
     @classmethod
     @classmethod
     def create(cls, *children, **props):
     def create(cls, *children, **props):

+ 4 - 1
reflex/components/el/elements/forms.pyi

@@ -271,6 +271,7 @@ class Fieldset(Element):
         ...
         ...
 
 
 def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]: ...
 def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]: ...
+def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]: ...
 
 
 class Form(BaseHTML):
 class Form(BaseHTML):
     @overload
     @overload
@@ -337,7 +338,9 @@ class Form(BaseHTML):
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
-        on_submit: Optional[EventType[Dict[str, Any]]] = None,
+        on_submit: Optional[
+            Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
+        ] = None,
         on_unmount: Optional[EventType[[]]] = None,
         on_unmount: Optional[EventType[[]]] = None,
         **props,
         **props,
     ) -> "Form":
     ) -> "Form":

+ 9 - 3
reflex/components/radix/primitives/form.pyi

@@ -129,7 +129,9 @@ class FormRoot(FormComponent, HTMLForm):
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
-        on_submit: Optional[EventType[Dict[str, Any]]] = None,
+        on_submit: Optional[
+            Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
+        ] = None,
         on_unmount: Optional[EventType[[]]] = None,
         on_unmount: Optional[EventType[[]]] = None,
         **props,
         **props,
     ) -> "FormRoot":
     ) -> "FormRoot":
@@ -596,7 +598,9 @@ class Form(FormRoot):
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
-        on_submit: Optional[EventType[Dict[str, Any]]] = None,
+        on_submit: Optional[
+            Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
+        ] = None,
         on_unmount: Optional[EventType[[]]] = None,
         on_unmount: Optional[EventType[[]]] = None,
         **props,
         **props,
     ) -> "Form":
     ) -> "Form":
@@ -720,7 +724,9 @@ class FormNamespace(ComponentNamespace):
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
-        on_submit: Optional[EventType[Dict[str, Any]]] = None,
+        on_submit: Optional[
+            Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
+        ] = None,
         on_unmount: Optional[EventType[[]]] = None,
         on_unmount: Optional[EventType[[]]] = None,
         **props,
         **props,
     ) -> "Form":
     ) -> "Form":

+ 7 - 15
reflex/components/radix/themes/components/slider.py

@@ -2,11 +2,11 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Literal, Optional, Union
 
 
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.core.breakpoints import Responsive
 from reflex.components.core.breakpoints import Responsive
-from reflex.event import EventHandler
+from reflex.event import EventHandler, identity_event
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
 from ..base import (
 from ..base import (
@@ -14,19 +14,11 @@ from ..base import (
     RadixThemesComponent,
     RadixThemesComponent,
 )
 )
 
 
-
-def on_value_event_spec(
-    value: Var[List[Union[int, float]]],
-) -> Tuple[Var[List[Union[int, float]]]]:
-    """Event handler spec for the value event.
-
-    Args:
-        value: The value of the event.
-
-    Returns:
-        The event handler spec.
-    """
-    return (value,)  # type: ignore
+on_value_event_spec = (
+    identity_event(list[Union[int, float]]),
+    identity_event(list[int]),
+    identity_event(list[float]),
+)
 
 
 
 
 class Slider(RadixThemesComponent):
 class Slider(RadixThemesComponent):

+ 21 - 7
reflex/components/radix/themes/components/slider.pyi

@@ -3,18 +3,20 @@
 # ------------------- DO NOT EDIT ----------------------
 # ------------------- DO NOT EDIT ----------------------
 # This file was generated by `reflex/utils/pyi_generator.py`!
 # This file was generated by `reflex/utils/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
+from typing import Any, Dict, List, Literal, Optional, Union, overload
 
 
 from reflex.components.core.breakpoints import Breakpoints
 from reflex.components.core.breakpoints import Breakpoints
-from reflex.event import EventType
+from reflex.event import EventType, identity_event
 from reflex.style import Style
 from reflex.style import Style
 from reflex.vars.base import Var
 from reflex.vars.base import Var
 
 
 from ..base import RadixThemesComponent
 from ..base import RadixThemesComponent
 
 
-def on_value_event_spec(
-    value: Var[List[Union[int, float]]],
-) -> Tuple[Var[List[Union[int, float]]]]: ...
+on_value_event_spec = (
+    identity_event(list[Union[int, float]]),
+    identity_event(list[int]),
+    identity_event(list[float]),
+)
 
 
 class Slider(RadixThemesComponent):
 class Slider(RadixThemesComponent):
     @overload
     @overload
@@ -138,7 +140,13 @@ class Slider(RadixThemesComponent):
         autofocus: Optional[bool] = None,
         autofocus: Optional[bool] = None,
         custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
         custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
         on_blur: Optional[EventType[[]]] = None,
         on_blur: Optional[EventType[[]]] = None,
-        on_change: Optional[EventType[List[Union[int, float]]]] = None,
+        on_change: Optional[
+            Union[
+                EventType[list[Union[int, float]]],
+                EventType[list[int]],
+                EventType[list[float]],
+            ]
+        ] = None,
         on_click: Optional[EventType[[]]] = None,
         on_click: Optional[EventType[[]]] = None,
         on_context_menu: Optional[EventType[[]]] = None,
         on_context_menu: Optional[EventType[[]]] = None,
         on_double_click: Optional[EventType[[]]] = None,
         on_double_click: Optional[EventType[[]]] = None,
@@ -153,7 +161,13 @@ class Slider(RadixThemesComponent):
         on_mouse_up: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
         on_unmount: Optional[EventType[[]]] = None,
         on_unmount: Optional[EventType[[]]] = None,
-        on_value_commit: Optional[EventType[List[Union[int, float]]]] = None,
+        on_value_commit: Optional[
+            Union[
+                EventType[list[Union[int, float]]],
+                EventType[list[int]],
+                EventType[list[float]],
+            ]
+        ] = None,
         **props,
         **props,
     ) -> "Slider":
     ) -> "Slider":
         """Create a Slider component.
         """Create a Slider component.

+ 138 - 16
reflex/event.py

@@ -29,8 +29,12 @@ from typing_extensions import ParamSpec, Protocol, get_args, get_origin
 
 
 from reflex import constants
 from reflex import constants
 from reflex.utils import console, format
 from reflex.utils import console, format
-from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
-from reflex.utils.types import ArgsSpec, GenericType
+from reflex.utils.exceptions import (
+    EventFnArgMismatch,
+    EventHandlerArgMismatch,
+    EventHandlerArgTypeMismatch,
+)
+from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
 from reflex.vars import VarData
 from reflex.vars import VarData
 from reflex.vars.base import (
 from reflex.vars.base import (
     LiteralVar,
     LiteralVar,
@@ -401,7 +405,9 @@ class EventChain(EventActionsMixin):
         default_factory=list
         default_factory=list
     )
     )
 
 
-    args_spec: Optional[Callable] = dataclasses.field(default=None)
+    args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field(
+        default=None
+    )
 
 
     invocation: Optional[Var] = dataclasses.field(default=None)
     invocation: Optional[Var] = dataclasses.field(default=None)
 
 
@@ -1053,7 +1059,8 @@ def get_hydrate_event(state) -> str:
 
 
 def call_event_handler(
 def call_event_handler(
     event_handler: EventHandler | EventSpec,
     event_handler: EventHandler | EventSpec,
-    arg_spec: ArgsSpec,
+    arg_spec: ArgsSpec | Sequence[ArgsSpec],
+    key: Optional[str] = None,
 ) -> EventSpec:
 ) -> EventSpec:
     """Call an event handler to get the event spec.
     """Call an event handler to get the event spec.
 
 
@@ -1064,12 +1071,16 @@ def call_event_handler(
     Args:
     Args:
         event_handler: The event handler.
         event_handler: The event handler.
         arg_spec: The lambda that define the argument(s) to pass to the event handler.
         arg_spec: The lambda that define the argument(s) to pass to the event handler.
+        key: The key to pass to the event handler.
 
 
     Raises:
     Raises:
         EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
         EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
 
 
     Returns:
     Returns:
         The event spec from calling the event handler.
         The event spec from calling the event handler.
+
+    # noqa: DAR401 failure
+
     """
     """
     parsed_args = parse_args_spec(arg_spec)  # type: ignore
     parsed_args = parse_args_spec(arg_spec)  # type: ignore
 
 
@@ -1077,19 +1088,113 @@ def call_event_handler(
         # Handle partial application of EventSpec args
         # Handle partial application of EventSpec args
         return event_handler.add_args(*parsed_args)
         return event_handler.add_args(*parsed_args)
 
 
-    args = inspect.getfullargspec(event_handler.fn).args
-    n_args = len(args) - 1  # subtract 1 for bound self arg
-    if n_args == len(parsed_args):
-        return event_handler(*parsed_args)  # type: ignore
-    else:
+    provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
+
+    provided_callback_n_args = (
+        len(provided_callback_fullspec.args) - 1
+    )  # subtract 1 for bound self arg
+
+    if provided_callback_n_args != len(parsed_args):
         raise EventHandlerArgMismatch(
         raise EventHandlerArgMismatch(
             "The number of arguments accepted by "
             "The number of arguments accepted by "
-            f"{event_handler.fn.__qualname__} ({n_args}) "
+            f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
             "does not match the arguments passed by the event trigger: "
             "does not match the arguments passed by the event trigger: "
             f"{[str(v) for v in parsed_args]}\n"
             f"{[str(v) for v in parsed_args]}\n"
             "See https://reflex.dev/docs/events/event-arguments/"
             "See https://reflex.dev/docs/events/event-arguments/"
         )
         )
 
 
+    all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
+
+    event_spec_return_types = list(
+        filter(
+            lambda event_spec_return_type: event_spec_return_type is not None
+            and get_origin(event_spec_return_type) is tuple,
+            (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
+        )
+    )
+
+    if event_spec_return_types:
+        failures = []
+
+        for event_spec_index, event_spec_return_type in enumerate(
+            event_spec_return_types
+        ):
+            args = get_args(event_spec_return_type)
+
+            args_types_without_vars = [
+                arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
+            ]
+
+            try:
+                type_hints_of_provided_callback = get_type_hints(event_handler.fn)
+            except NameError:
+                type_hints_of_provided_callback = {}
+
+            failed_type_check = False
+
+            # check that args of event handler are matching the spec if type hints are provided
+            for i, arg in enumerate(provided_callback_fullspec.args[1:]):
+                if arg not in type_hints_of_provided_callback:
+                    continue
+
+                try:
+                    compare_result = typehint_issubclass(
+                        args_types_without_vars[i], type_hints_of_provided_callback[arg]
+                    )
+                except TypeError:
+                    # TODO: In 0.7.0, remove this block and raise the exception
+                    # raise TypeError(
+                    #     f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
+                    # ) from e
+                    console.warn(
+                        f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
+                    )
+                    compare_result = False
+
+                if compare_result:
+                    continue
+                else:
+                    failure = EventHandlerArgTypeMismatch(
+                        f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
+                    )
+                    failures.append(failure)
+                    failed_type_check = True
+                    break
+
+            if not failed_type_check:
+                if event_spec_index:
+                    args = get_args(event_spec_return_types[0])
+
+                    args_types_without_vars = [
+                        arg if get_origin(arg) is not Var else get_args(arg)[0]
+                        for arg in args
+                    ]
+
+                    expect_string = ", ".join(
+                        repr(arg) for arg in args_types_without_vars
+                    ).replace("[", "\\[")
+
+                    given_string = ", ".join(
+                        repr(type_hints_of_provided_callback.get(arg, Any))
+                        for arg in provided_callback_fullspec.args[1:]
+                    ).replace("[", "\\[")
+
+                    console.warn(
+                        f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
+                        f"This may lead to unexpected behavior but is intentionally ignored for {key}."
+                    )
+                return event_handler(*parsed_args)
+
+        if failures:
+            console.deprecate(
+                "Mismatched event handler argument types",
+                "\n".join([str(f) for f in failures]),
+                "0.6.5",
+                "0.7.0",
+            )
+
+    return event_handler(*parsed_args)  # type: ignore
+
 
 
 def unwrap_var_annotation(annotation: GenericType):
 def unwrap_var_annotation(annotation: GenericType):
     """Unwrap a Var annotation or return it as is if it's not Var[X].
     """Unwrap a Var annotation or return it as is if it's not Var[X].
@@ -1128,7 +1233,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
     return annotation
     return annotation
 
 
 
 
-def parse_args_spec(arg_spec: ArgsSpec):
+def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
     """Parse the args provided in the ArgsSpec of an event trigger.
     """Parse the args provided in the ArgsSpec of an event trigger.
 
 
     Args:
     Args:
@@ -1137,6 +1242,8 @@ def parse_args_spec(arg_spec: ArgsSpec):
     Returns:
     Returns:
         The parsed args.
         The parsed args.
     """
     """
+    # if there's multiple, the first is the default
+    arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec
     spec = inspect.getfullargspec(arg_spec)
     spec = inspect.getfullargspec(arg_spec)
     annotations = get_type_hints(arg_spec)
     annotations = get_type_hints(arg_spec)
 
 
@@ -1152,13 +1259,18 @@ def parse_args_spec(arg_spec: ArgsSpec):
     )
     )
 
 
 
 
-def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
+def check_fn_match_arg_spec(
+    fn: Callable,
+    arg_spec: ArgsSpec,
+    key: Optional[str] = None,
+) -> List[Var]:
     """Ensures that the function signature matches the passed argument specification
     """Ensures that the function signature matches the passed argument specification
     or raises an EventFnArgMismatch if they do not.
     or raises an EventFnArgMismatch if they do not.
 
 
     Args:
     Args:
         fn: The function to be validated.
         fn: The function to be validated.
         arg_spec: The argument specification for the event trigger.
         arg_spec: The argument specification for the event trigger.
+        key: The key to pass to the event handler.
 
 
     Returns:
     Returns:
         The parsed arguments from the argument specification.
         The parsed arguments from the argument specification.
@@ -1184,7 +1296,11 @@ def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
     return parsed_args
     return parsed_args
 
 
 
 
-def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
+def call_event_fn(
+    fn: Callable,
+    arg_spec: ArgsSpec,
+    key: Optional[str] = None,
+) -> list[EventSpec] | Var:
     """Call a function to a list of event specs.
     """Call a function to a list of event specs.
 
 
     The function should return a single EventSpec, a list of EventSpecs, or a
     The function should return a single EventSpec, a list of EventSpecs, or a
@@ -1193,6 +1309,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     Args:
     Args:
         fn: The function to call.
         fn: The function to call.
         arg_spec: The argument spec for the event trigger.
         arg_spec: The argument spec for the event trigger.
+        key: The key to pass to the event handler.
 
 
     Returns:
     Returns:
         The event specs from calling the function or a Var.
         The event specs from calling the function or a Var.
@@ -1205,7 +1322,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     from reflex.utils.exceptions import EventHandlerValueError
     from reflex.utils.exceptions import EventHandlerValueError
 
 
     # Check that fn signature matches arg_spec
     # Check that fn signature matches arg_spec
-    parsed_args = check_fn_match_arg_spec(fn, arg_spec)
+    parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
 
 
     # Call the function with the parsed args.
     # Call the function with the parsed args.
     out = fn(*parsed_args)
     out = fn(*parsed_args)
@@ -1223,7 +1340,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     for e in out:
     for e in out:
         if isinstance(e, EventHandler):
         if isinstance(e, EventHandler):
             # An un-called EventHandler gets all of the args of the event trigger.
             # An un-called EventHandler gets all of the args of the event trigger.
-            e = call_event_handler(e, arg_spec)
+            e = call_event_handler(e, arg_spec, key=key)
 
 
         # Make sure the event spec is valid.
         # Make sure the event spec is valid.
         if not isinstance(e, EventSpec):
         if not isinstance(e, EventSpec):
@@ -1433,7 +1550,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
         Returns:
         Returns:
             The created LiteralEventChainVar instance.
             The created LiteralEventChainVar instance.
         """
         """
-        sig = inspect.signature(value.args_spec)  # type: ignore
+        arg_spec = (
+            value.args_spec[0]
+            if isinstance(value.args_spec, Sequence)
+            else value.args_spec
+        )
+        sig = inspect.signature(arg_spec)  # type: ignore
         if sig.parameters:
         if sig.parameters:
             arg_def = tuple((f"_{p}" for p in sig.parameters))
             arg_def = tuple((f"_{p}" for p in sig.parameters))
             arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
             arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])

+ 5 - 1
reflex/utils/exceptions.py

@@ -90,7 +90,11 @@ class MatchTypeError(ReflexError, TypeError):
 
 
 
 
 class EventHandlerArgMismatch(ReflexError, TypeError):
 class EventHandlerArgMismatch(ReflexError, TypeError):
-    """Raised when the number of args accepted by an EventHandler is differs from that provided by the event trigger."""
+    """Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
+
+
+class EventHandlerArgTypeMismatch(ReflexError, TypeError):
+    """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
 
 
 
 
 class EventFnArgMismatch(ReflexError, TypeError):
 class EventFnArgMismatch(ReflexError, TypeError):

+ 32 - 16
reflex/utils/pyi_generator.py

@@ -490,7 +490,7 @@ def _generate_component_create_functiondef(
 
 
     def figure_out_return_type(annotation: Any):
     def figure_out_return_type(annotation: Any):
         if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
         if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
-            return ast.Name(id="Optional[EventType]")
+            return ast.Name(id="EventType")
 
 
         if not isinstance(annotation, str) and get_origin(annotation) is tuple:
         if not isinstance(annotation, str) and get_origin(annotation) is tuple:
             arguments = get_args(annotation)
             arguments = get_args(annotation)
@@ -509,20 +509,13 @@ def _generate_component_create_functiondef(
             # Create EventType using the joined string
             # Create EventType using the joined string
             event_type = ast.Name(id=f"EventType[{args_str}]")
             event_type = ast.Name(id=f"EventType[{args_str}]")
 
 
-            # Wrap in Optional
-            optional_type = ast.Subscript(
-                value=ast.Name(id="Optional"),
-                slice=ast.Index(value=event_type),
-                ctx=ast.Load(),
-            )
-
-            return ast.Name(id=ast.unparse(optional_type))
+            return event_type
 
 
         if isinstance(annotation, str) and annotation.startswith("Tuple["):
         if isinstance(annotation, str) and annotation.startswith("Tuple["):
             inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
             inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
 
 
             if inside_of_tuple == "()":
             if inside_of_tuple == "()":
-                return ast.Name(id="Optional[EventType[[]]]")
+                return ast.Name(id="EventType[[]]")
 
 
             arguments = [""]
             arguments = [""]
 
 
@@ -548,10 +541,8 @@ def _generate_component_create_functiondef(
                 for argument in arguments
                 for argument in arguments
             ]
             ]
 
 
-            return ast.Name(
-                id=f"Optional[EventType[{', '.join(arguments_without_var)}]]"
-            )
-        return ast.Name(id="Optional[EventType]")
+            return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]")
+        return ast.Name(id="EventType")
 
 
     event_triggers = clz().get_event_triggers()
     event_triggers = clz().get_event_triggers()
 
 
@@ -560,8 +551,33 @@ def _generate_component_create_functiondef(
         (
         (
             ast.arg(
             ast.arg(
                 arg=trigger,
                 arg=trigger,
-                annotation=figure_out_return_type(
-                    inspect.signature(event_triggers[trigger]).return_annotation
+                annotation=ast.Subscript(
+                    ast.Name("Optional"),
+                    ast.Index(  # type: ignore
+                        value=ast.Name(
+                            id=ast.unparse(
+                                figure_out_return_type(
+                                    inspect.signature(event_specs).return_annotation
+                                )
+                                if not isinstance(
+                                    event_specs := event_triggers[trigger], tuple
+                                )
+                                else ast.Subscript(
+                                    ast.Name("Union"),
+                                    ast.Tuple(
+                                        [
+                                            figure_out_return_type(
+                                                inspect.signature(
+                                                    event_spec
+                                                ).return_annotation
+                                            )
+                                            for event_spec in event_specs
+                                        ]
+                                    ),
+                                )
+                            )
+                        )
+                    ),
                 ),
                 ),
             ),
             ),
             ast.Constant(value=None),
             ast.Constant(value=None),

+ 66 - 0
reflex/utils/types.py

@@ -774,3 +774,69 @@ def validate_parameter_literals(func):
 # Store this here for performance.
 # Store this here for performance.
 StateBases = get_base_class(StateVar)
 StateBases = get_base_class(StateVar)
 StateIterBases = get_base_class(StateIterVar)
 StateIterBases = get_base_class(StateIterVar)
+
+
+def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
+    """Check if a type hint is a subclass of another type hint.
+
+    Args:
+        possible_subclass: The type hint to check.
+        possible_superclass: The type hint to check against.
+
+    Returns:
+        Whether the type hint is a subclass of the other type hint.
+    """
+    if possible_superclass is Any:
+        return True
+    if possible_subclass is Any:
+        return False
+
+    provided_type_origin = get_origin(possible_subclass)
+    accepted_type_origin = get_origin(possible_superclass)
+
+    if provided_type_origin is None and accepted_type_origin is None:
+        # In this case, we are dealing with a non-generic type, so we can use issubclass
+        return issubclass(possible_subclass, possible_superclass)
+
+    # Remove this check when Python 3.10 is the minimum supported version
+    if hasattr(types, "UnionType"):
+        provided_type_origin = (
+            Union if provided_type_origin is types.UnionType else provided_type_origin
+        )
+        accepted_type_origin = (
+            Union if accepted_type_origin is types.UnionType else accepted_type_origin
+        )
+
+    # Get type arguments (e.g., [float, int] for Dict[float, int])
+    provided_args = get_args(possible_subclass)
+    accepted_args = get_args(possible_superclass)
+
+    if accepted_type_origin is Union:
+        if provided_type_origin is not Union:
+            return any(
+                typehint_issubclass(possible_subclass, accepted_arg)
+                for accepted_arg in accepted_args
+            )
+        return all(
+            any(
+                typehint_issubclass(provided_arg, accepted_arg)
+                for accepted_arg in accepted_args
+            )
+            for provided_arg in provided_args
+        )
+
+    # Check if the origin of both types is the same (e.g., list for List[int])
+    # This probably should be issubclass instead of ==
+    if (provided_type_origin or possible_subclass) != (
+        accepted_type_origin or possible_superclass
+    ):
+        return False
+
+    # Ensure all specific types are compatible with accepted types
+    # Note this is not necessarily correct, as it doesn't check against contravariance and covariance
+    # It also ignores when the length of the arguments is different
+    return all(
+        typehint_issubclass(provided_arg, accepted_arg)
+        for provided_arg, accepted_arg in zip(provided_args, accepted_args)
+        if accepted_arg is not Any
+    )

+ 45 - 4
tests/units/components/test_component.py

@@ -20,13 +20,17 @@ from reflex.event import (
     EventChain,
     EventChain,
     EventHandler,
     EventHandler,
     empty_event,
     empty_event,
+    identity_event,
     input_event,
     input_event,
     parse_args_spec,
     parse_args_spec,
 )
 )
 from reflex.state import BaseState
 from reflex.state import BaseState
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
+from reflex.utils.exceptions import (
+    EventFnArgMismatch,
+    EventHandlerArgMismatch,
+)
 from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.vars import VarData
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.base import LiteralVar, Var
@@ -43,6 +47,18 @@ def test_state():
         def do_something_arg(self, arg):
         def do_something_arg(self, arg):
             pass
             pass
 
 
+        def do_something_with_bool(self, arg: bool):
+            pass
+
+        def do_something_with_int(self, arg: int):
+            pass
+
+        def do_something_with_list_int(self, arg: list[int]):
+            pass
+
+        def do_something_with_list_str(self, arg: list[str]):
+            pass
+
     return TestState
     return TestState
 
 
 
 
@@ -95,8 +111,10 @@ def component2() -> Type[Component]:
             """
             """
             return {
             return {
                 **super().get_event_triggers(),
                 **super().get_event_triggers(),
-                "on_open": lambda e0: [e0],
-                "on_close": lambda e0: [e0],
+                "on_open": identity_event(bool),
+                "on_close": identity_event(bool),
+                "on_user_visited_count_changed": identity_event(int),
+                "on_user_list_changed": identity_event(List[str]),
             }
             }
 
 
         def _get_imports(self) -> ParsedImportDict:
         def _get_imports(self) -> ParsedImportDict:
@@ -582,7 +600,14 @@ def test_get_event_triggers(component1, component2):
     assert component1().get_event_triggers().keys() == default_triggers
     assert component1().get_event_triggers().keys() == default_triggers
     assert (
     assert (
         component2().get_event_triggers().keys()
         component2().get_event_triggers().keys()
-        == {"on_open", "on_close", "on_prop_event"} | default_triggers
+        == {
+            "on_open",
+            "on_close",
+            "on_prop_event",
+            "on_user_visited_count_changed",
+            "on_user_list_changed",
+        }
+        | default_triggers
     )
     )
 
 
 
 
@@ -903,6 +928,22 @@ def test_invalid_event_handler_args(component2, test_state):
             on_prop_event=[test_state.do_something_arg, test_state.do_something]
             on_prop_event=[test_state.do_something_arg, test_state.do_something]
         )
         )
 
 
+    # Enable when 0.7.0 happens
+    # # Event Handler types must match
+    # with pytest.raises(EventHandlerArgTypeMismatch):
+    #     component2.create(
+    #         on_user_visited_count_changed=test_state.do_something_with_bool
+    #     )
+    # with pytest.raises(EventHandlerArgTypeMismatch):
+    #     component2.create(on_user_list_changed=test_state.do_something_with_int)
+    # with pytest.raises(EventHandlerArgTypeMismatch):
+    #     component2.create(on_user_list_changed=test_state.do_something_with_list_int)
+
+    # component2.create(on_open=test_state.do_something_with_int)
+    # component2.create(on_open=test_state.do_something_with_bool)
+    # component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
+    # component2.create(on_user_list_changed=test_state.do_something_with_list_str)
+
     # lambda cannot return weird values.
     # lambda cannot return weird values.
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
         component2.create(on_click=lambda: 1)
         component2.create(on_click=lambda: 1)

+ 42 - 1
tests/units/utils/test_utils.py

@@ -2,7 +2,7 @@ import os
 import typing
 import typing
 from functools import cached_property
 from functools import cached_property
 from pathlib import Path
 from pathlib import Path
-from typing import Any, ClassVar, List, Literal, Type, Union
+from typing import Any, ClassVar, Dict, List, Literal, Type, Union
 
 
 import pytest
 import pytest
 import typer
 import typer
@@ -77,6 +77,47 @@ def test_is_generic_alias(cls: type, expected: bool):
     assert types.is_generic_alias(cls) == expected
     assert types.is_generic_alias(cls) == expected
 
 
 
 
+@pytest.mark.parametrize(
+    ("subclass", "superclass", "expected"),
+    [
+        *[
+            (base_type, base_type, True)
+            for base_type in [int, float, str, bool, list, dict]
+        ],
+        *[
+            (one_type, another_type, False)
+            for one_type in [int, float, str, list, dict]
+            for another_type in [int, float, str, list, dict]
+            if one_type != another_type
+        ],
+        (bool, int, True),
+        (int, bool, False),
+        (list, List, True),
+        (list, List[str], True),  # this is wrong, but it's a limitation of the function
+        (List, list, True),
+        (List[int], list, True),
+        (List[int], List, True),
+        (List[int], List[str], False),
+        (List[int], List[int], True),
+        (List[int], List[float], False),
+        (List[int], List[Union[int, float]], True),
+        (List[int], List[Union[float, str]], False),
+        (Union[int, float], List[Union[int, float]], False),
+        (Union[int, float], Union[int, float, str], True),
+        (Union[int, float], Union[str, float], False),
+        (Dict[str, int], Dict[str, int], True),
+        (Dict[str, bool], Dict[str, int], True),
+        (Dict[str, int], Dict[str, bool], False),
+        (Dict[str, Any], dict[str, str], False),
+        (Dict[str, str], dict[str, str], True),
+        (Dict[str, str], dict[str, Any], True),
+        (Dict[str, Any], dict[str, Any], True),
+    ],
+)
+def test_typehint_issubclass(subclass, superclass, expected):
+    assert types.typehint_issubclass(subclass, superclass) == expected
+
+
 def test_validate_invalid_bun_path(mocker):
 def test_validate_invalid_bun_path(mocker):
     """Test that an error is thrown when a custom specified bun path is not valid
     """Test that an error is thrown when a custom specified bun path is not valid
     or does not exist.
     or does not exist.