1
0
Эх сурвалжийг харах

[REF-2272] Support declaring EventHandlers directly in component (#2952)

Martin Xu 1 жил өмнө
parent
commit
f372402ee4

+ 1 - 0
reflex/__init__.py

@@ -131,6 +131,7 @@ _MAPPING = {
     "reflex.event": [
     "reflex.event": [
         "event",
         "event",
         "EventChain",
         "EventChain",
+        "EventHandler",
         "background",
         "background",
         "call_script",
         "call_script",
         "clear_local_storage",
         "clear_local_storage",

+ 1 - 0
reflex/__init__.pyi

@@ -113,6 +113,7 @@ from reflex import constants as constants
 from reflex.constants import Env as Env
 from reflex.constants import Env as Env
 from reflex import event as event
 from reflex import event as event
 from reflex.event import EventChain as EventChain
 from reflex.event import EventChain as EventChain
+from reflex.event import EventHandler as EventHandler
 from reflex.event import background as background
 from reflex.event import background as background
 from reflex.event import call_script as call_script
 from reflex.event import call_script as call_script
 from reflex.event import clear_local_storage as clear_local_storage
 from reflex.event import clear_local_storage as clear_local_storage

+ 32 - 16
reflex/components/component.py

@@ -236,6 +236,8 @@ class Component(BaseComponent, ABC):
             if types._issubclass(field.type_, Var):
             if types._issubclass(field.type_, Var):
                 field.required = False
                 field.required = False
                 field.default = Var.create(field.default)
                 field.default = Var.create(field.default)
+            elif types._issubclass(field.type_, EventHandler):
+                field.required = False
 
 
         # Ensure renamed props from parent classes are applied to the subclass.
         # Ensure renamed props from parent classes are applied to the subclass.
         if cls._rename_props:
         if cls._rename_props:
@@ -272,7 +274,8 @@ class Component(BaseComponent, ABC):
 
 
         # Get the component fields, triggers, and props.
         # Get the component fields, triggers, and props.
         fields = self.get_fields()
         fields = self.get_fields()
-        triggers = self.get_event_triggers().keys()
+        component_specific_triggers = self.get_event_triggers()
+        triggers = component_specific_triggers.keys()
         props = self.get_props()
         props = self.get_props()
 
 
         # Add any events triggers.
         # Add any events triggers.
@@ -327,7 +330,9 @@ class Component(BaseComponent, ABC):
             # Check if the key is an event trigger.
             # Check if the key is an event trigger.
             if key in triggers:
             if key in triggers:
                 # Temporarily disable full control for event triggers.
                 # Temporarily disable full control for event triggers.
-                kwargs["event_triggers"][key] = self._create_event_chain(key, value)
+                kwargs["event_triggers"][key] = self._create_event_chain(
+                    value=value, args_spec=component_specific_triggers[key]
+                )
 
 
         # Remove any keys that were added as events.
         # Remove any keys that were added as events.
         for key in kwargs["event_triggers"]:
         for key in kwargs["event_triggers"]:
@@ -359,7 +364,7 @@ class Component(BaseComponent, ABC):
 
 
     def _create_event_chain(
     def _create_event_chain(
         self,
         self,
-        event_trigger: str,
+        args_spec: Any,
         value: Union[
         value: Union[
             Var, EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], Callable
             Var, EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], Callable
         ],
         ],
@@ -367,7 +372,7 @@ class Component(BaseComponent, ABC):
         """Create an event chain from a variety of input types.
         """Create an event chain from a variety of input types.
 
 
         Args:
         Args:
-            event_trigger: The event trigger to bind the chain to.
+            args_spec: The args_spec of the the event trigger being bound.
             value: The value to create the event chain from.
             value: The value to create the event chain from.
 
 
         Returns:
         Returns:
@@ -376,9 +381,6 @@ class Component(BaseComponent, ABC):
         Raises:
         Raises:
             ValueError: If the value is not a valid event chain.
             ValueError: If the value is not a valid event chain.
         """
         """
-        # Check if the trigger is a controlled event.
-        triggers = self.get_event_triggers()
-
         # If it's an event chain var, return it.
         # If it's an event chain var, return it.
         if isinstance(value, Var):
         if isinstance(value, Var):
             if value._var_type is not EventChain:
             if value._var_type is not EventChain:
@@ -388,8 +390,6 @@ class Component(BaseComponent, ABC):
             # Trust that the caller knows what they're doing passing an EventChain directly
             # Trust that the caller knows what they're doing passing an EventChain directly
             return value
             return value
 
 
-        arg_spec = triggers.get(event_trigger, lambda: [])
-
         # If the input is a single event handler, wrap it in a list.
         # If the input is a single event handler, wrap it in a list.
         if isinstance(value, (EventHandler, EventSpec)):
         if isinstance(value, (EventHandler, EventSpec)):
             value = [value]
             value = [value]
@@ -401,7 +401,7 @@ class Component(BaseComponent, ABC):
                 if isinstance(v, EventHandler):
                 if isinstance(v, EventHandler):
                     # Call the event handler to get the event.
                     # Call the event handler to get the event.
                     try:
                     try:
-                        event = call_event_handler(v, arg_spec)  # type: ignore
+                        event = call_event_handler(v, args_spec)
                     except ValueError as err:
                     except ValueError as err:
                         raise ValueError(
                         raise ValueError(
                             f" {err} defined in the `{type(self).__name__}` component"
                             f" {err} defined in the `{type(self).__name__}` component"
@@ -414,13 +414,13 @@ class Component(BaseComponent, ABC):
                     events.append(v)
                     events.append(v)
                 elif isinstance(v, Callable):
                 elif isinstance(v, Callable):
                     # Call the lambda to get the event chain.
                     # Call the lambda to get the event chain.
-                    events.extend(call_event_fn(v, arg_spec))  # type: ignore
+                    events.extend(call_event_fn(v, args_spec))
                 else:
                 else:
                     raise ValueError(f"Invalid event: {v}")
                     raise ValueError(f"Invalid event: {v}")
 
 
         # If the input is a callable, create an event chain.
         # If the input is a callable, create an event chain.
         elif isinstance(value, Callable):
         elif isinstance(value, Callable):
-            events = call_event_fn(value, arg_spec)  # type: ignore
+            events = call_event_fn(value, args_spec)
 
 
         # Otherwise, raise an error.
         # Otherwise, raise an error.
         else:
         else:
@@ -435,7 +435,7 @@ class Component(BaseComponent, ABC):
             event_actions.update(e.event_actions)
             event_actions.update(e.event_actions)
 
 
         # Return the event chain.
         # Return the event chain.
-        if isinstance(arg_spec, Var):
+        if isinstance(args_spec, Var):
             return EventChain(
             return EventChain(
                 events=events,
                 events=events,
                 args_spec=None,
                 args_spec=None,
@@ -444,7 +444,7 @@ class Component(BaseComponent, ABC):
         else:
         else:
             return EventChain(
             return EventChain(
                 events=events,
                 events=events,
-                args_spec=arg_spec,  # type: ignore
+                args_spec=args_spec,
                 event_actions=event_actions,
                 event_actions=event_actions,
             )
             )
 
 
@@ -454,7 +454,7 @@ class Component(BaseComponent, ABC):
         Returns:
         Returns:
             The event triggers.
             The event triggers.
         """
         """
-        return {
+        default_triggers = {
             EventTriggers.ON_FOCUS: lambda: [],
             EventTriggers.ON_FOCUS: lambda: [],
             EventTriggers.ON_BLUR: lambda: [],
             EventTriggers.ON_BLUR: lambda: [],
             EventTriggers.ON_CLICK: lambda: [],
             EventTriggers.ON_CLICK: lambda: [],
@@ -471,6 +471,14 @@ class Component(BaseComponent, ABC):
             EventTriggers.ON_MOUNT: lambda: [],
             EventTriggers.ON_MOUNT: lambda: [],
             EventTriggers.ON_UNMOUNT: lambda: [],
             EventTriggers.ON_UNMOUNT: lambda: [],
         }
         }
+        # Look for component specific triggers,
+        # e.g. variable declared as EventHandler types.
+        for field in self.get_fields().values():
+            if types._issubclass(field.type_, EventHandler):
+                default_triggers[field.name] = getattr(
+                    field.type_, "args_spec", lambda: []
+                )
+        return default_triggers
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         """Represent the component in React.
         """Represent the component in React.
@@ -1352,6 +1360,9 @@ class CustomComponent(Component):
         # Set the tag to the name of the function.
         # Set the tag to the name of the function.
         self.tag = format.to_title_case(self.component_fn.__name__)
         self.tag = format.to_title_case(self.component_fn.__name__)
 
 
+        # Get the event triggers defined in the component declaration.
+        event_triggers_in_component_declaration = self.get_event_triggers()
+
         # Set the props.
         # Set the props.
         props = typing.get_type_hints(self.component_fn)
         props = typing.get_type_hints(self.component_fn)
         for key, value in kwargs.items():
         for key, value in kwargs.items():
@@ -1364,7 +1375,12 @@ class CustomComponent(Component):
 
 
             # Handle event chains.
             # Handle event chains.
             if types._issubclass(type_, EventChain):
             if types._issubclass(type_, EventChain):
-                value = self._create_event_chain(key, value)
+                value = self._create_event_chain(
+                    value=value,
+                    args_spec=event_triggers_in_component_declaration.get(
+                        key, lambda: []
+                    ),
+                )
                 self.props[format.to_camel_case(key)] = value
                 self.props[format.to_camel_case(key)] = value
                 continue
                 continue
 
 

+ 18 - 1
reflex/event.py

@@ -1,4 +1,5 @@
 """Define event classes to connect the frontend and backend."""
 """Define event classes to connect the frontend and backend."""
+
 from __future__ import annotations
 from __future__ import annotations
 
 
 import inspect
 import inspect
@@ -12,6 +13,7 @@ from typing import (
     Optional,
     Optional,
     Tuple,
     Tuple,
     Union,
     Union,
+    _GenericAlias,  # type: ignore
     get_type_hints,
     get_type_hints,
 )
 )
 
 
@@ -106,7 +108,7 @@ class EventHandler(EventActionsMixin):
     fn: Any
     fn: Any
 
 
     # The full name of the state class this event handler is attached to.
     # The full name of the state class this event handler is attached to.
-    # Emtpy string means this event handler is a server side event.
+    # Empty string means this event handler is a server side event.
     state_full_name: str = ""
     state_full_name: str = ""
 
 
     class Config:
     class Config:
@@ -115,6 +117,21 @@ class EventHandler(EventActionsMixin):
         # Needed to allow serialization of Callable.
         # Needed to allow serialization of Callable.
         frozen = True
         frozen = True
 
 
+    @classmethod
+    def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
+        """Get a typed EventHandler.
+
+        Args:
+            args_spec: The args_spec of the EventHandler.
+
+        Returns:
+            The EventHandler class item.
+        """
+        gen = _GenericAlias(cls, Any)
+        # Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute.
+        gen.args_spec = args_spec
+        return gen
+
     @property
     @property
     def is_background(self) -> bool:
     def is_background(self) -> bool:
         """Whether the event handler is a background task.
         """Whether the event handler is a background task.

+ 32 - 0
tests/components/test_component.py

@@ -1350,3 +1350,35 @@ def test_custom_component_add_imports(tags):
 
 
     assert baseline.get_imports() == {"react": _list_to_import_vars(tags)}
     assert baseline.get_imports() == {"react": _list_to_import_vars(tags)}
     assert test.get_imports() == baseline.get_imports()
     assert test.get_imports() == baseline.get_imports()
+
+
+def test_custom_component_declare_event_handlers_in_fields():
+    class ReferenceComponent(Component):
+        def get_event_triggers(self) -> Dict[str, Any]:
+            """Test controlled triggers.
+
+            Returns:
+                Test controlled triggers.
+            """
+            return {
+                **super().get_event_triggers(),
+                "on_a": lambda e: [e],
+                "on_b": lambda e: [e.target.value],
+                "on_c": lambda e: [],
+                "on_d": lambda: [],
+                "on_e": lambda: [],
+            }
+
+    class TestComponent(Component):
+        on_a: EventHandler[lambda e0: [e0]]
+        on_b: EventHandler[lambda e0: [e0.target.value]]
+        on_c: EventHandler[lambda e0: []]
+        on_d: EventHandler[lambda: []]
+        on_e: EventHandler
+
+    custom_component = ReferenceComponent.create()
+    test_component = TestComponent.create()
+    assert (
+        custom_component.get_event_triggers().keys()
+        == test_component.get_event_triggers().keys()
+    )

+ 37 - 0
tests/components/test_component_future_annotations.py

@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from typing import Any
+
+from reflex.components.component import Component
+from reflex.event import EventHandler
+
+
+# This is a repeat of its namesake in test_component.py.
+def test_custom_component_declare_event_handlers_in_fields():
+    class ReferenceComponent(Component):
+        def get_event_triggers(self) -> dict[str, Any]:
+            """Test controlled triggers.
+
+            Returns:
+                Test controlled triggers.
+            """
+            return {
+                **super().get_event_triggers(),
+                "on_a": lambda e: [e],
+                "on_b": lambda e: [e.target.value],
+                "on_c": lambda e: [],
+                "on_d": lambda: [],
+            }
+
+    class TestComponent(Component):
+        on_a: EventHandler[lambda e0: [e0]]
+        on_b: EventHandler[lambda e0: [e0.target.value]]
+        on_c: EventHandler[lambda e0: []]
+        on_d: EventHandler[lambda: []]
+
+    custom_component = ReferenceComponent.create()
+    test_component = TestComponent.create()
+    assert (
+        custom_component.get_event_triggers().keys()
+        == test_component.get_event_triggers().keys()
+    )