浏览代码

[ENG-3943]type check for event handler if spec arg are typed (#4046)

* type check for event handler if spec arg are typed

* fix the typecheck logic

* rearrange logic pieces

* add try except

* add try except around compare

* change form and improve type checking

* print key instead

* dang it darglint

* change wording

* add basic test to cover it

* add a slightly more complicated test

* challenge it a bit by doing small capital list

* add multiple argspec

* fix slider event order

* i hate 3.9

* add note for UnionType

* move function to types

* add a test for type hint is subclass

* make on submit dict str any

* add testing for dict cases

* add check against any

* accept dict str str

* bruh i used i twice

* escape strings and print actual error message

* disable the error and print deprecation warning instead

* disable tests

* fix doc message

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Thomas Brandého 6 月之前
父节点
当前提交
c07eb2a6a0

+ 8 - 4
reflex/components/component.py

@@ -480,6 +480,7 @@ class Component(BaseComponent, ABC):
                 kwargs["event_triggers"][key] = self._create_event_chain(
                     value=value,  # type: ignore
                     args_spec=component_specific_triggers[key],
+                    key=key,
                 )
 
         # Remove any keys that were added as events.
@@ -540,12 +541,14 @@ class Component(BaseComponent, ABC):
             List[Union[EventHandler, EventSpec, EventVar]],
             Callable,
         ],
+        key: Optional[str] = None,
     ) -> Union[EventChain, Var]:
         """Create an event chain from a variety of input types.
 
         Args:
             args_spec: The args_spec of the event trigger being bound.
             value: The value to create the event chain from.
+            key: The key of the event trigger being bound.
 
         Returns:
             The event chain.
@@ -560,7 +563,7 @@ class Component(BaseComponent, ABC):
             elif isinstance(value, EventVar):
                 value = [value]
             elif issubclass(value._var_type, (EventChain, EventSpec)):
-                return self._create_event_chain(args_spec, value.guess_type())
+                return self._create_event_chain(args_spec, value.guess_type(), key=key)
             else:
                 raise ValueError(
                     f"Invalid event chain: {str(value)} of type {value._var_type}"
@@ -579,10 +582,10 @@ class Component(BaseComponent, ABC):
             for v in value:
                 if isinstance(v, (EventHandler, EventSpec)):
                     # Call the event handler to get the event.
-                    events.append(call_event_handler(v, args_spec))
+                    events.append(call_event_handler(v, args_spec, key=key))
                 elif isinstance(v, Callable):
                     # Call the lambda to get the event chain.
-                    result = call_event_fn(v, args_spec)
+                    result = call_event_fn(v, args_spec, key=key)
                     if isinstance(result, Var):
                         raise ValueError(
                             f"Invalid event chain: {v}. Cannot use a Var-returning "
@@ -599,7 +602,7 @@ class Component(BaseComponent, ABC):
             result = call_event_fn(value, args_spec)
             if isinstance(result, Var):
                 # Recursively call this function if the lambda returned an EventChain Var.
-                return self._create_event_chain(args_spec, result)
+                return self._create_event_chain(args_spec, result, key=key)
             events = [*result]
 
         # Otherwise, raise an error.
@@ -1722,6 +1725,7 @@ class CustomComponent(Component):
                     args_spec=event_triggers_in_component_declaration.get(
                         key, empty_event
                     ),
+                    key=key,
                 )
                 self.props[format.to_camel_case(key)] = value
                 continue

+ 10 - 1
reflex/components/el/elements/forms.py

@@ -111,6 +111,15 @@ def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]:
     return (FORM_DATA,)
 
 
+def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]:
+    """Event handler spec for the on_submit event.
+
+    Returns:
+        The event handler spec.
+    """
+    return (FORM_DATA,)
+
+
 class Form(BaseHTML):
     """Display the form element."""
 
@@ -150,7 +159,7 @@ class Form(BaseHTML):
     handle_submit_unique_name: Var[str]
 
     # Fired when the form is submitted
-    on_submit: EventHandler[on_submit_event_spec]
+    on_submit: EventHandler[on_submit_event_spec, on_submit_string_event_spec]
 
     @classmethod
     def create(cls, *children, **props):

+ 4 - 1
reflex/components/el/elements/forms.pyi

@@ -271,6 +271,7 @@ class Fieldset(Element):
         ...
 
 def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]: ...
+def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]: ...
 
 class Form(BaseHTML):
     @overload
@@ -337,7 +338,9 @@ class Form(BaseHTML):
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
-        on_submit: Optional[EventType[Dict[str, Any]]] = None,
+        on_submit: Optional[
+            Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
+        ] = None,
         on_unmount: Optional[EventType[[]]] = None,
         **props,
     ) -> "Form":

+ 9 - 3
reflex/components/radix/primitives/form.pyi

@@ -129,7 +129,9 @@ class FormRoot(FormComponent, HTMLForm):
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
-        on_submit: Optional[EventType[Dict[str, Any]]] = None,
+        on_submit: Optional[
+            Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
+        ] = None,
         on_unmount: Optional[EventType[[]]] = None,
         **props,
     ) -> "FormRoot":
@@ -596,7 +598,9 @@ class Form(FormRoot):
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
-        on_submit: Optional[EventType[Dict[str, Any]]] = None,
+        on_submit: Optional[
+            Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
+        ] = None,
         on_unmount: Optional[EventType[[]]] = None,
         **props,
     ) -> "Form":
@@ -720,7 +724,9 @@ class FormNamespace(ComponentNamespace):
         on_mouse_over: Optional[EventType[[]]] = None,
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
-        on_submit: Optional[EventType[Dict[str, Any]]] = None,
+        on_submit: Optional[
+            Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]]
+        ] = None,
         on_unmount: Optional[EventType[[]]] = None,
         **props,
     ) -> "Form":

+ 7 - 15
reflex/components/radix/themes/components/slider.py

@@ -2,11 +2,11 @@
 
 from __future__ import annotations
 
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Literal, Optional, Union
 
 from reflex.components.component import Component
 from reflex.components.core.breakpoints import Responsive
-from reflex.event import EventHandler
+from reflex.event import EventHandler, identity_event
 from reflex.vars.base import Var
 
 from ..base import (
@@ -14,19 +14,11 @@ from ..base import (
     RadixThemesComponent,
 )
 
-
-def on_value_event_spec(
-    value: Var[List[Union[int, float]]],
-) -> Tuple[Var[List[Union[int, float]]]]:
-    """Event handler spec for the value event.
-
-    Args:
-        value: The value of the event.
-
-    Returns:
-        The event handler spec.
-    """
-    return (value,)  # type: ignore
+on_value_event_spec = (
+    identity_event(list[Union[int, float]]),
+    identity_event(list[int]),
+    identity_event(list[float]),
+)
 
 
 class Slider(RadixThemesComponent):

+ 21 - 7
reflex/components/radix/themes/components/slider.pyi

@@ -3,18 +3,20 @@
 # ------------------- DO NOT EDIT ----------------------
 # This file was generated by `reflex/utils/pyi_generator.py`!
 # ------------------------------------------------------
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
+from typing import Any, Dict, List, Literal, Optional, Union, overload
 
 from reflex.components.core.breakpoints import Breakpoints
-from reflex.event import EventType
+from reflex.event import EventType, identity_event
 from reflex.style import Style
 from reflex.vars.base import Var
 
 from ..base import RadixThemesComponent
 
-def on_value_event_spec(
-    value: Var[List[Union[int, float]]],
-) -> Tuple[Var[List[Union[int, float]]]]: ...
+on_value_event_spec = (
+    identity_event(list[Union[int, float]]),
+    identity_event(list[int]),
+    identity_event(list[float]),
+)
 
 class Slider(RadixThemesComponent):
     @overload
@@ -138,7 +140,13 @@ class Slider(RadixThemesComponent):
         autofocus: Optional[bool] = None,
         custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
         on_blur: Optional[EventType[[]]] = None,
-        on_change: Optional[EventType[List[Union[int, float]]]] = None,
+        on_change: Optional[
+            Union[
+                EventType[list[Union[int, float]]],
+                EventType[list[int]],
+                EventType[list[float]],
+            ]
+        ] = None,
         on_click: Optional[EventType[[]]] = None,
         on_context_menu: Optional[EventType[[]]] = None,
         on_double_click: Optional[EventType[[]]] = None,
@@ -153,7 +161,13 @@ class Slider(RadixThemesComponent):
         on_mouse_up: Optional[EventType[[]]] = None,
         on_scroll: Optional[EventType[[]]] = None,
         on_unmount: Optional[EventType[[]]] = None,
-        on_value_commit: Optional[EventType[List[Union[int, float]]]] = None,
+        on_value_commit: Optional[
+            Union[
+                EventType[list[Union[int, float]]],
+                EventType[list[int]],
+                EventType[list[float]],
+            ]
+        ] = None,
         **props,
     ) -> "Slider":
         """Create a Slider component.

+ 138 - 16
reflex/event.py

@@ -29,8 +29,12 @@ from typing_extensions import ParamSpec, Protocol, get_args, get_origin
 
 from reflex import constants
 from reflex.utils import console, format
-from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
-from reflex.utils.types import ArgsSpec, GenericType
+from reflex.utils.exceptions import (
+    EventFnArgMismatch,
+    EventHandlerArgMismatch,
+    EventHandlerArgTypeMismatch,
+)
+from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
 from reflex.vars import VarData
 from reflex.vars.base import (
     LiteralVar,
@@ -401,7 +405,9 @@ class EventChain(EventActionsMixin):
         default_factory=list
     )
 
-    args_spec: Optional[Callable] = dataclasses.field(default=None)
+    args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field(
+        default=None
+    )
 
     invocation: Optional[Var] = dataclasses.field(default=None)
 
@@ -1053,7 +1059,8 @@ def get_hydrate_event(state) -> str:
 
 def call_event_handler(
     event_handler: EventHandler | EventSpec,
-    arg_spec: ArgsSpec,
+    arg_spec: ArgsSpec | Sequence[ArgsSpec],
+    key: Optional[str] = None,
 ) -> EventSpec:
     """Call an event handler to get the event spec.
 
@@ -1064,12 +1071,16 @@ def call_event_handler(
     Args:
         event_handler: The event handler.
         arg_spec: The lambda that define the argument(s) to pass to the event handler.
+        key: The key to pass to the event handler.
 
     Raises:
         EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
 
     Returns:
         The event spec from calling the event handler.
+
+    # noqa: DAR401 failure
+
     """
     parsed_args = parse_args_spec(arg_spec)  # type: ignore
 
@@ -1077,19 +1088,113 @@ def call_event_handler(
         # Handle partial application of EventSpec args
         return event_handler.add_args(*parsed_args)
 
-    args = inspect.getfullargspec(event_handler.fn).args
-    n_args = len(args) - 1  # subtract 1 for bound self arg
-    if n_args == len(parsed_args):
-        return event_handler(*parsed_args)  # type: ignore
-    else:
+    provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
+
+    provided_callback_n_args = (
+        len(provided_callback_fullspec.args) - 1
+    )  # subtract 1 for bound self arg
+
+    if provided_callback_n_args != len(parsed_args):
         raise EventHandlerArgMismatch(
             "The number of arguments accepted by "
-            f"{event_handler.fn.__qualname__} ({n_args}) "
+            f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
             "does not match the arguments passed by the event trigger: "
             f"{[str(v) for v in parsed_args]}\n"
             "See https://reflex.dev/docs/events/event-arguments/"
         )
 
+    all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
+
+    event_spec_return_types = list(
+        filter(
+            lambda event_spec_return_type: event_spec_return_type is not None
+            and get_origin(event_spec_return_type) is tuple,
+            (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
+        )
+    )
+
+    if event_spec_return_types:
+        failures = []
+
+        for event_spec_index, event_spec_return_type in enumerate(
+            event_spec_return_types
+        ):
+            args = get_args(event_spec_return_type)
+
+            args_types_without_vars = [
+                arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
+            ]
+
+            try:
+                type_hints_of_provided_callback = get_type_hints(event_handler.fn)
+            except NameError:
+                type_hints_of_provided_callback = {}
+
+            failed_type_check = False
+
+            # check that args of event handler are matching the spec if type hints are provided
+            for i, arg in enumerate(provided_callback_fullspec.args[1:]):
+                if arg not in type_hints_of_provided_callback:
+                    continue
+
+                try:
+                    compare_result = typehint_issubclass(
+                        args_types_without_vars[i], type_hints_of_provided_callback[arg]
+                    )
+                except TypeError:
+                    # TODO: In 0.7.0, remove this block and raise the exception
+                    # raise TypeError(
+                    #     f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
+                    # ) from e
+                    console.warn(
+                        f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
+                    )
+                    compare_result = False
+
+                if compare_result:
+                    continue
+                else:
+                    failure = EventHandlerArgTypeMismatch(
+                        f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
+                    )
+                    failures.append(failure)
+                    failed_type_check = True
+                    break
+
+            if not failed_type_check:
+                if event_spec_index:
+                    args = get_args(event_spec_return_types[0])
+
+                    args_types_without_vars = [
+                        arg if get_origin(arg) is not Var else get_args(arg)[0]
+                        for arg in args
+                    ]
+
+                    expect_string = ", ".join(
+                        repr(arg) for arg in args_types_without_vars
+                    ).replace("[", "\\[")
+
+                    given_string = ", ".join(
+                        repr(type_hints_of_provided_callback.get(arg, Any))
+                        for arg in provided_callback_fullspec.args[1:]
+                    ).replace("[", "\\[")
+
+                    console.warn(
+                        f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
+                        f"This may lead to unexpected behavior but is intentionally ignored for {key}."
+                    )
+                return event_handler(*parsed_args)
+
+        if failures:
+            console.deprecate(
+                "Mismatched event handler argument types",
+                "\n".join([str(f) for f in failures]),
+                "0.6.5",
+                "0.7.0",
+            )
+
+    return event_handler(*parsed_args)  # type: ignore
+
 
 def unwrap_var_annotation(annotation: GenericType):
     """Unwrap a Var annotation or return it as is if it's not Var[X].
@@ -1128,7 +1233,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
     return annotation
 
 
-def parse_args_spec(arg_spec: ArgsSpec):
+def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
     """Parse the args provided in the ArgsSpec of an event trigger.
 
     Args:
@@ -1137,6 +1242,8 @@ def parse_args_spec(arg_spec: ArgsSpec):
     Returns:
         The parsed args.
     """
+    # if there's multiple, the first is the default
+    arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec
     spec = inspect.getfullargspec(arg_spec)
     annotations = get_type_hints(arg_spec)
 
@@ -1152,13 +1259,18 @@ def parse_args_spec(arg_spec: ArgsSpec):
     )
 
 
-def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
+def check_fn_match_arg_spec(
+    fn: Callable,
+    arg_spec: ArgsSpec,
+    key: Optional[str] = None,
+) -> List[Var]:
     """Ensures that the function signature matches the passed argument specification
     or raises an EventFnArgMismatch if they do not.
 
     Args:
         fn: The function to be validated.
         arg_spec: The argument specification for the event trigger.
+        key: The key to pass to the event handler.
 
     Returns:
         The parsed arguments from the argument specification.
@@ -1184,7 +1296,11 @@ def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
     return parsed_args
 
 
-def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
+def call_event_fn(
+    fn: Callable,
+    arg_spec: ArgsSpec,
+    key: Optional[str] = None,
+) -> list[EventSpec] | Var:
     """Call a function to a list of event specs.
 
     The function should return a single EventSpec, a list of EventSpecs, or a
@@ -1193,6 +1309,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     Args:
         fn: The function to call.
         arg_spec: The argument spec for the event trigger.
+        key: The key to pass to the event handler.
 
     Returns:
         The event specs from calling the function or a Var.
@@ -1205,7 +1322,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     from reflex.utils.exceptions import EventHandlerValueError
 
     # Check that fn signature matches arg_spec
-    parsed_args = check_fn_match_arg_spec(fn, arg_spec)
+    parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
 
     # Call the function with the parsed args.
     out = fn(*parsed_args)
@@ -1223,7 +1340,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     for e in out:
         if isinstance(e, EventHandler):
             # An un-called EventHandler gets all of the args of the event trigger.
-            e = call_event_handler(e, arg_spec)
+            e = call_event_handler(e, arg_spec, key=key)
 
         # Make sure the event spec is valid.
         if not isinstance(e, EventSpec):
@@ -1433,7 +1550,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
         Returns:
             The created LiteralEventChainVar instance.
         """
-        sig = inspect.signature(value.args_spec)  # type: ignore
+        arg_spec = (
+            value.args_spec[0]
+            if isinstance(value.args_spec, Sequence)
+            else value.args_spec
+        )
+        sig = inspect.signature(arg_spec)  # type: ignore
         if sig.parameters:
             arg_def = tuple((f"_{p}" for p in sig.parameters))
             arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])

+ 5 - 1
reflex/utils/exceptions.py

@@ -90,7 +90,11 @@ class MatchTypeError(ReflexError, TypeError):
 
 
 class EventHandlerArgMismatch(ReflexError, TypeError):
-    """Raised when the number of args accepted by an EventHandler is differs from that provided by the event trigger."""
+    """Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
+
+
+class EventHandlerArgTypeMismatch(ReflexError, TypeError):
+    """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
 
 
 class EventFnArgMismatch(ReflexError, TypeError):

+ 32 - 16
reflex/utils/pyi_generator.py

@@ -490,7 +490,7 @@ def _generate_component_create_functiondef(
 
     def figure_out_return_type(annotation: Any):
         if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
-            return ast.Name(id="Optional[EventType]")
+            return ast.Name(id="EventType")
 
         if not isinstance(annotation, str) and get_origin(annotation) is tuple:
             arguments = get_args(annotation)
@@ -509,20 +509,13 @@ def _generate_component_create_functiondef(
             # Create EventType using the joined string
             event_type = ast.Name(id=f"EventType[{args_str}]")
 
-            # Wrap in Optional
-            optional_type = ast.Subscript(
-                value=ast.Name(id="Optional"),
-                slice=ast.Index(value=event_type),
-                ctx=ast.Load(),
-            )
-
-            return ast.Name(id=ast.unparse(optional_type))
+            return event_type
 
         if isinstance(annotation, str) and annotation.startswith("Tuple["):
             inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
 
             if inside_of_tuple == "()":
-                return ast.Name(id="Optional[EventType[[]]]")
+                return ast.Name(id="EventType[[]]")
 
             arguments = [""]
 
@@ -548,10 +541,8 @@ def _generate_component_create_functiondef(
                 for argument in arguments
             ]
 
-            return ast.Name(
-                id=f"Optional[EventType[{', '.join(arguments_without_var)}]]"
-            )
-        return ast.Name(id="Optional[EventType]")
+            return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]")
+        return ast.Name(id="EventType")
 
     event_triggers = clz().get_event_triggers()
 
@@ -560,8 +551,33 @@ def _generate_component_create_functiondef(
         (
             ast.arg(
                 arg=trigger,
-                annotation=figure_out_return_type(
-                    inspect.signature(event_triggers[trigger]).return_annotation
+                annotation=ast.Subscript(
+                    ast.Name("Optional"),
+                    ast.Index(  # type: ignore
+                        value=ast.Name(
+                            id=ast.unparse(
+                                figure_out_return_type(
+                                    inspect.signature(event_specs).return_annotation
+                                )
+                                if not isinstance(
+                                    event_specs := event_triggers[trigger], tuple
+                                )
+                                else ast.Subscript(
+                                    ast.Name("Union"),
+                                    ast.Tuple(
+                                        [
+                                            figure_out_return_type(
+                                                inspect.signature(
+                                                    event_spec
+                                                ).return_annotation
+                                            )
+                                            for event_spec in event_specs
+                                        ]
+                                    ),
+                                )
+                            )
+                        )
+                    ),
                 ),
             ),
             ast.Constant(value=None),

+ 66 - 0
reflex/utils/types.py

@@ -774,3 +774,69 @@ def validate_parameter_literals(func):
 # Store this here for performance.
 StateBases = get_base_class(StateVar)
 StateIterBases = get_base_class(StateIterVar)
+
+
+def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
+    """Check if a type hint is a subclass of another type hint.
+
+    Args:
+        possible_subclass: The type hint to check.
+        possible_superclass: The type hint to check against.
+
+    Returns:
+        Whether the type hint is a subclass of the other type hint.
+    """
+    if possible_superclass is Any:
+        return True
+    if possible_subclass is Any:
+        return False
+
+    provided_type_origin = get_origin(possible_subclass)
+    accepted_type_origin = get_origin(possible_superclass)
+
+    if provided_type_origin is None and accepted_type_origin is None:
+        # In this case, we are dealing with a non-generic type, so we can use issubclass
+        return issubclass(possible_subclass, possible_superclass)
+
+    # Remove this check when Python 3.10 is the minimum supported version
+    if hasattr(types, "UnionType"):
+        provided_type_origin = (
+            Union if provided_type_origin is types.UnionType else provided_type_origin
+        )
+        accepted_type_origin = (
+            Union if accepted_type_origin is types.UnionType else accepted_type_origin
+        )
+
+    # Get type arguments (e.g., [float, int] for Dict[float, int])
+    provided_args = get_args(possible_subclass)
+    accepted_args = get_args(possible_superclass)
+
+    if accepted_type_origin is Union:
+        if provided_type_origin is not Union:
+            return any(
+                typehint_issubclass(possible_subclass, accepted_arg)
+                for accepted_arg in accepted_args
+            )
+        return all(
+            any(
+                typehint_issubclass(provided_arg, accepted_arg)
+                for accepted_arg in accepted_args
+            )
+            for provided_arg in provided_args
+        )
+
+    # Check if the origin of both types is the same (e.g., list for List[int])
+    # This probably should be issubclass instead of ==
+    if (provided_type_origin or possible_subclass) != (
+        accepted_type_origin or possible_superclass
+    ):
+        return False
+
+    # Ensure all specific types are compatible with accepted types
+    # Note this is not necessarily correct, as it doesn't check against contravariance and covariance
+    # It also ignores when the length of the arguments is different
+    return all(
+        typehint_issubclass(provided_arg, accepted_arg)
+        for provided_arg, accepted_arg in zip(provided_args, accepted_args)
+        if accepted_arg is not Any
+    )

+ 45 - 4
tests/units/components/test_component.py

@@ -20,13 +20,17 @@ from reflex.event import (
     EventChain,
     EventHandler,
     empty_event,
+    identity_event,
     input_event,
     parse_args_spec,
 )
 from reflex.state import BaseState
 from reflex.style import Style
 from reflex.utils import imports
-from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
+from reflex.utils.exceptions import (
+    EventFnArgMismatch,
+    EventHandlerArgMismatch,
+)
 from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
@@ -43,6 +47,18 @@ def test_state():
         def do_something_arg(self, arg):
             pass
 
+        def do_something_with_bool(self, arg: bool):
+            pass
+
+        def do_something_with_int(self, arg: int):
+            pass
+
+        def do_something_with_list_int(self, arg: list[int]):
+            pass
+
+        def do_something_with_list_str(self, arg: list[str]):
+            pass
+
     return TestState
 
 
@@ -95,8 +111,10 @@ def component2() -> Type[Component]:
             """
             return {
                 **super().get_event_triggers(),
-                "on_open": lambda e0: [e0],
-                "on_close": lambda e0: [e0],
+                "on_open": identity_event(bool),
+                "on_close": identity_event(bool),
+                "on_user_visited_count_changed": identity_event(int),
+                "on_user_list_changed": identity_event(List[str]),
             }
 
         def _get_imports(self) -> ParsedImportDict:
@@ -582,7 +600,14 @@ def test_get_event_triggers(component1, component2):
     assert component1().get_event_triggers().keys() == default_triggers
     assert (
         component2().get_event_triggers().keys()
-        == {"on_open", "on_close", "on_prop_event"} | default_triggers
+        == {
+            "on_open",
+            "on_close",
+            "on_prop_event",
+            "on_user_visited_count_changed",
+            "on_user_list_changed",
+        }
+        | default_triggers
     )
 
 
@@ -903,6 +928,22 @@ def test_invalid_event_handler_args(component2, test_state):
             on_prop_event=[test_state.do_something_arg, test_state.do_something]
         )
 
+    # Enable when 0.7.0 happens
+    # # Event Handler types must match
+    # with pytest.raises(EventHandlerArgTypeMismatch):
+    #     component2.create(
+    #         on_user_visited_count_changed=test_state.do_something_with_bool
+    #     )
+    # with pytest.raises(EventHandlerArgTypeMismatch):
+    #     component2.create(on_user_list_changed=test_state.do_something_with_int)
+    # with pytest.raises(EventHandlerArgTypeMismatch):
+    #     component2.create(on_user_list_changed=test_state.do_something_with_list_int)
+
+    # component2.create(on_open=test_state.do_something_with_int)
+    # component2.create(on_open=test_state.do_something_with_bool)
+    # component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
+    # component2.create(on_user_list_changed=test_state.do_something_with_list_str)
+
     # lambda cannot return weird values.
     with pytest.raises(ValueError):
         component2.create(on_click=lambda: 1)

+ 42 - 1
tests/units/utils/test_utils.py

@@ -2,7 +2,7 @@ import os
 import typing
 from functools import cached_property
 from pathlib import Path
-from typing import Any, ClassVar, List, Literal, Type, Union
+from typing import Any, ClassVar, Dict, List, Literal, Type, Union
 
 import pytest
 import typer
@@ -77,6 +77,47 @@ def test_is_generic_alias(cls: type, expected: bool):
     assert types.is_generic_alias(cls) == expected
 
 
+@pytest.mark.parametrize(
+    ("subclass", "superclass", "expected"),
+    [
+        *[
+            (base_type, base_type, True)
+            for base_type in [int, float, str, bool, list, dict]
+        ],
+        *[
+            (one_type, another_type, False)
+            for one_type in [int, float, str, list, dict]
+            for another_type in [int, float, str, list, dict]
+            if one_type != another_type
+        ],
+        (bool, int, True),
+        (int, bool, False),
+        (list, List, True),
+        (list, List[str], True),  # this is wrong, but it's a limitation of the function
+        (List, list, True),
+        (List[int], list, True),
+        (List[int], List, True),
+        (List[int], List[str], False),
+        (List[int], List[int], True),
+        (List[int], List[float], False),
+        (List[int], List[Union[int, float]], True),
+        (List[int], List[Union[float, str]], False),
+        (Union[int, float], List[Union[int, float]], False),
+        (Union[int, float], Union[int, float, str], True),
+        (Union[int, float], Union[str, float], False),
+        (Dict[str, int], Dict[str, int], True),
+        (Dict[str, bool], Dict[str, int], True),
+        (Dict[str, int], Dict[str, bool], False),
+        (Dict[str, Any], dict[str, str], False),
+        (Dict[str, str], dict[str, str], True),
+        (Dict[str, str], dict[str, Any], True),
+        (Dict[str, Any], dict[str, Any], True),
+    ],
+)
+def test_typehint_issubclass(subclass, superclass, expected):
+    assert types.typehint_issubclass(subclass, superclass) == expected
+
+
 def test_validate_invalid_bun_path(mocker):
     """Test that an error is thrown when a custom specified bun path is not valid
     or does not exist.