浏览代码

[REF-2272] Support declaring EventHandlers directly in component (#2952)

Martin Xu 1 年之前
父节点
当前提交
f372402ee4

+ 1 - 0
reflex/__init__.py

@@ -131,6 +131,7 @@ _MAPPING = {
     "reflex.event": [
         "event",
         "EventChain",
+        "EventHandler",
         "background",
         "call_script",
         "clear_local_storage",

+ 1 - 0
reflex/__init__.pyi

@@ -113,6 +113,7 @@ from reflex import constants as constants
 from reflex.constants import Env as Env
 from reflex import event as event
 from reflex.event import EventChain as EventChain
+from reflex.event import EventHandler as EventHandler
 from reflex.event import background as background
 from reflex.event import call_script as call_script
 from reflex.event import clear_local_storage as clear_local_storage

+ 32 - 16
reflex/components/component.py

@@ -236,6 +236,8 @@ class Component(BaseComponent, ABC):
             if types._issubclass(field.type_, Var):
                 field.required = False
                 field.default = Var.create(field.default)
+            elif types._issubclass(field.type_, EventHandler):
+                field.required = False
 
         # Ensure renamed props from parent classes are applied to the subclass.
         if cls._rename_props:
@@ -272,7 +274,8 @@ class Component(BaseComponent, ABC):
 
         # Get the component fields, triggers, and props.
         fields = self.get_fields()
-        triggers = self.get_event_triggers().keys()
+        component_specific_triggers = self.get_event_triggers()
+        triggers = component_specific_triggers.keys()
         props = self.get_props()
 
         # Add any events triggers.
@@ -327,7 +330,9 @@ class Component(BaseComponent, ABC):
             # Check if the key is an event trigger.
             if key in triggers:
                 # Temporarily disable full control for event triggers.
-                kwargs["event_triggers"][key] = self._create_event_chain(key, value)
+                kwargs["event_triggers"][key] = self._create_event_chain(
+                    value=value, args_spec=component_specific_triggers[key]
+                )
 
         # Remove any keys that were added as events.
         for key in kwargs["event_triggers"]:
@@ -359,7 +364,7 @@ class Component(BaseComponent, ABC):
 
     def _create_event_chain(
         self,
-        event_trigger: str,
+        args_spec: Any,
         value: Union[
             Var, EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], Callable
         ],
@@ -367,7 +372,7 @@ class Component(BaseComponent, ABC):
         """Create an event chain from a variety of input types.
 
         Args:
-            event_trigger: The event trigger to bind the chain to.
+            args_spec: The args_spec of the the event trigger being bound.
             value: The value to create the event chain from.
 
         Returns:
@@ -376,9 +381,6 @@ class Component(BaseComponent, ABC):
         Raises:
             ValueError: If the value is not a valid event chain.
         """
-        # Check if the trigger is a controlled event.
-        triggers = self.get_event_triggers()
-
         # If it's an event chain var, return it.
         if isinstance(value, Var):
             if value._var_type is not EventChain:
@@ -388,8 +390,6 @@ class Component(BaseComponent, ABC):
             # Trust that the caller knows what they're doing passing an EventChain directly
             return value
 
-        arg_spec = triggers.get(event_trigger, lambda: [])
-
         # If the input is a single event handler, wrap it in a list.
         if isinstance(value, (EventHandler, EventSpec)):
             value = [value]
@@ -401,7 +401,7 @@ class Component(BaseComponent, ABC):
                 if isinstance(v, EventHandler):
                     # Call the event handler to get the event.
                     try:
-                        event = call_event_handler(v, arg_spec)  # type: ignore
+                        event = call_event_handler(v, args_spec)
                     except ValueError as err:
                         raise ValueError(
                             f" {err} defined in the `{type(self).__name__}` component"
@@ -414,13 +414,13 @@ class Component(BaseComponent, ABC):
                     events.append(v)
                 elif isinstance(v, Callable):
                     # Call the lambda to get the event chain.
-                    events.extend(call_event_fn(v, arg_spec))  # type: ignore
+                    events.extend(call_event_fn(v, args_spec))
                 else:
                     raise ValueError(f"Invalid event: {v}")
 
         # If the input is a callable, create an event chain.
         elif isinstance(value, Callable):
-            events = call_event_fn(value, arg_spec)  # type: ignore
+            events = call_event_fn(value, args_spec)
 
         # Otherwise, raise an error.
         else:
@@ -435,7 +435,7 @@ class Component(BaseComponent, ABC):
             event_actions.update(e.event_actions)
 
         # Return the event chain.
-        if isinstance(arg_spec, Var):
+        if isinstance(args_spec, Var):
             return EventChain(
                 events=events,
                 args_spec=None,
@@ -444,7 +444,7 @@ class Component(BaseComponent, ABC):
         else:
             return EventChain(
                 events=events,
-                args_spec=arg_spec,  # type: ignore
+                args_spec=args_spec,
                 event_actions=event_actions,
             )
 
@@ -454,7 +454,7 @@ class Component(BaseComponent, ABC):
         Returns:
             The event triggers.
         """
-        return {
+        default_triggers = {
             EventTriggers.ON_FOCUS: lambda: [],
             EventTriggers.ON_BLUR: lambda: [],
             EventTriggers.ON_CLICK: lambda: [],
@@ -471,6 +471,14 @@ class Component(BaseComponent, ABC):
             EventTriggers.ON_MOUNT: lambda: [],
             EventTriggers.ON_UNMOUNT: lambda: [],
         }
+        # Look for component specific triggers,
+        # e.g. variable declared as EventHandler types.
+        for field in self.get_fields().values():
+            if types._issubclass(field.type_, EventHandler):
+                default_triggers[field.name] = getattr(
+                    field.type_, "args_spec", lambda: []
+                )
+        return default_triggers
 
     def __repr__(self) -> str:
         """Represent the component in React.
@@ -1352,6 +1360,9 @@ class CustomComponent(Component):
         # Set the tag to the name of the function.
         self.tag = format.to_title_case(self.component_fn.__name__)
 
+        # Get the event triggers defined in the component declaration.
+        event_triggers_in_component_declaration = self.get_event_triggers()
+
         # Set the props.
         props = typing.get_type_hints(self.component_fn)
         for key, value in kwargs.items():
@@ -1364,7 +1375,12 @@ class CustomComponent(Component):
 
             # Handle event chains.
             if types._issubclass(type_, EventChain):
-                value = self._create_event_chain(key, value)
+                value = self._create_event_chain(
+                    value=value,
+                    args_spec=event_triggers_in_component_declaration.get(
+                        key, lambda: []
+                    ),
+                )
                 self.props[format.to_camel_case(key)] = value
                 continue
 

+ 18 - 1
reflex/event.py

@@ -1,4 +1,5 @@
 """Define event classes to connect the frontend and backend."""
+
 from __future__ import annotations
 
 import inspect
@@ -12,6 +13,7 @@ from typing import (
     Optional,
     Tuple,
     Union,
+    _GenericAlias,  # type: ignore
     get_type_hints,
 )
 
@@ -106,7 +108,7 @@ class EventHandler(EventActionsMixin):
     fn: Any
 
     # The full name of the state class this event handler is attached to.
-    # Emtpy string means this event handler is a server side event.
+    # Empty string means this event handler is a server side event.
     state_full_name: str = ""
 
     class Config:
@@ -115,6 +117,21 @@ class EventHandler(EventActionsMixin):
         # Needed to allow serialization of Callable.
         frozen = True
 
+    @classmethod
+    def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
+        """Get a typed EventHandler.
+
+        Args:
+            args_spec: The args_spec of the EventHandler.
+
+        Returns:
+            The EventHandler class item.
+        """
+        gen = _GenericAlias(cls, Any)
+        # Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute.
+        gen.args_spec = args_spec
+        return gen
+
     @property
     def is_background(self) -> bool:
         """Whether the event handler is a background task.

+ 32 - 0
tests/components/test_component.py

@@ -1350,3 +1350,35 @@ def test_custom_component_add_imports(tags):
 
     assert baseline.get_imports() == {"react": _list_to_import_vars(tags)}
     assert test.get_imports() == baseline.get_imports()
+
+
+def test_custom_component_declare_event_handlers_in_fields():
+    class ReferenceComponent(Component):
+        def get_event_triggers(self) -> Dict[str, Any]:
+            """Test controlled triggers.
+
+            Returns:
+                Test controlled triggers.
+            """
+            return {
+                **super().get_event_triggers(),
+                "on_a": lambda e: [e],
+                "on_b": lambda e: [e.target.value],
+                "on_c": lambda e: [],
+                "on_d": lambda: [],
+                "on_e": lambda: [],
+            }
+
+    class TestComponent(Component):
+        on_a: EventHandler[lambda e0: [e0]]
+        on_b: EventHandler[lambda e0: [e0.target.value]]
+        on_c: EventHandler[lambda e0: []]
+        on_d: EventHandler[lambda: []]
+        on_e: EventHandler
+
+    custom_component = ReferenceComponent.create()
+    test_component = TestComponent.create()
+    assert (
+        custom_component.get_event_triggers().keys()
+        == test_component.get_event_triggers().keys()
+    )

+ 37 - 0
tests/components/test_component_future_annotations.py

@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from typing import Any
+
+from reflex.components.component import Component
+from reflex.event import EventHandler
+
+
+# This is a repeat of its namesake in test_component.py.
+def test_custom_component_declare_event_handlers_in_fields():
+    class ReferenceComponent(Component):
+        def get_event_triggers(self) -> dict[str, Any]:
+            """Test controlled triggers.
+
+            Returns:
+                Test controlled triggers.
+            """
+            return {
+                **super().get_event_triggers(),
+                "on_a": lambda e: [e],
+                "on_b": lambda e: [e.target.value],
+                "on_c": lambda e: [],
+                "on_d": lambda: [],
+            }
+
+    class TestComponent(Component):
+        on_a: EventHandler[lambda e0: [e0]]
+        on_b: EventHandler[lambda e0: [e0.target.value]]
+        on_c: EventHandler[lambda e0: []]
+        on_d: EventHandler[lambda: []]
+
+    custom_component = ReferenceComponent.create()
+    test_component = TestComponent.create()
+    assert (
+        custom_component.get_event_triggers().keys()
+        == test_component.get_event_triggers().keys()
+    )