Browse Source

Fix annotated EventHandler (#3076)

Masen Furer 1 year ago
parent
commit
fc0be257a3
3 changed files with 27 additions and 17 deletions
  1. 5 3
      reflex/components/component.py
  2. 7 6
      reflex/event.py
  3. 15 8
      tests/components/test_component.py

+ 5 - 3
reflex/components/component.py

@@ -474,9 +474,11 @@ class Component(BaseComponent, ABC):
         # e.g. variable declared as EventHandler types.
         for field in self.get_fields().values():
             if types._issubclass(field.type_, EventHandler):
-                default_triggers[field.name] = getattr(
-                    field.type_, "args_spec", lambda: []
-                )
+                args_spec = None
+                annotation = field.annotation
+                if hasattr(annotation, "__metadata__"):
+                    args_spec = annotation.__metadata__[0]
+                default_triggers[field.name] = args_spec or (lambda: [])
         return default_triggers
 
     def __repr__(self) -> str:

+ 7 - 6
reflex/event.py

@@ -13,7 +13,6 @@ from typing import (
     Optional,
     Tuple,
     Union,
-    _GenericAlias,  # type: ignore
     get_type_hints,
 )
 
@@ -23,6 +22,11 @@ from reflex.utils import console, format
 from reflex.utils.types import ArgsSpec
 from reflex.vars import BaseVar, Var
 
+try:
+    from typing import Annotated
+except ImportError:
+    from typing_extensions import Annotated
+
 
 class Event(Base):
     """An event that describes any state change in the app."""
@@ -118,7 +122,7 @@ class EventHandler(EventActionsMixin):
         frozen = True
 
     @classmethod
-    def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
+    def __class_getitem__(cls, args_spec: str) -> Annotated:
         """Get a typed EventHandler.
 
         Args:
@@ -127,10 +131,7 @@ class EventHandler(EventActionsMixin):
         Returns:
             The EventHandler class item.
         """
-        gen = _GenericAlias(cls, Any)
-        # Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute.
-        gen.args_spec = args_spec
-        return gen
+        return Annotated[cls, args_spec]
 
     @property
     def is_background(self) -> bool:

+ 15 - 8
tests/components/test_component.py

@@ -15,7 +15,7 @@ from reflex.components.component import (
     custom_component,
 )
 from reflex.constants import EventTriggers
-from reflex.event import EventChain, EventHandler
+from reflex.event import EventChain, EventHandler, parse_args_spec
 from reflex.state import BaseState
 from reflex.style import Style
 from reflex.utils import imports
@@ -1542,11 +1542,12 @@ def test_custom_component_declare_event_handlers_in_fields():
             """
             return {
                 **super().get_event_triggers(),
-                "on_a": lambda e: [e],
-                "on_b": lambda e: [e.target.value],
-                "on_c": lambda e: [],
+                "on_a": lambda e0: [e0],
+                "on_b": lambda e0: [e0.target.value],
+                "on_c": lambda e0: [],
                 "on_d": lambda: [],
                 "on_e": lambda: [],
+                "on_f": lambda a, b, c: [c, b, a],
             }
 
     class TestComponent(Component):
@@ -1555,10 +1556,16 @@ def test_custom_component_declare_event_handlers_in_fields():
         on_c: EventHandler[lambda e0: []]
         on_d: EventHandler[lambda: []]
         on_e: EventHandler
+        on_f: EventHandler[lambda a, b, c: [c, b, a]]
 
     custom_component = ReferenceComponent.create()
     test_component = TestComponent.create()
-    assert (
-        custom_component.get_event_triggers().keys()
-        == test_component.get_event_triggers().keys()
-    )
+    custom_triggers = custom_component.get_event_triggers()
+    test_triggers = test_component.get_event_triggers()
+    assert custom_triggers.keys() == test_triggers.keys()
+    for trigger_name in custom_component.get_event_triggers():
+        for v1, v2 in zip(
+            parse_args_spec(test_triggers[trigger_name]),
+            parse_args_spec(custom_triggers[trigger_name]),
+        ):
+            assert v1.equals(v2)