Sfoglia il codice sorgente

decentralized event handlers (#5227)

* decentralized event handlers

* make the name extra complicated to make sure no clash can happen between different guys
Khaleel Al-Adhami 2 settimane fa
parent
commit
5acd50ec5a
3 ha cambiato i file con 84 aggiunte e 0 eliminazioni
  1. 24 0
      reflex/event.py
  2. 16 0
      reflex/state.py
  3. 44 0
      tests/units/test_event.py

+ 24 - 0
reflex/event.py

@@ -2066,6 +2066,30 @@ class EventNamespace:
                 setattr(func, BACKGROUND_TASK_MARKER, True)
             if getattr(func, "__name__", "").startswith("_"):
                 raise ValueError("Event handlers cannot be private.")
+
+            qualname: str | None = getattr(func, "__qualname__", None)
+
+            if qualname and (
+                len(func_path := qualname.split(".")) == 1
+                or func_path[-2] == "<locals>"
+            ):
+                from reflex.state import BaseState
+
+                types = get_type_hints(func)
+                state_arg_name = next(iter(inspect.signature(func).parameters), None)
+                state_cls = state_arg_name and types.get(state_arg_name)
+                if state_cls and issubclass(state_cls, BaseState):
+                    name = (
+                        (func.__module__ + "." + qualname)
+                        .replace(".", "_")
+                        .replace("<locals>", "_")
+                        .removeprefix("_")
+                    )
+                    object.__setattr__(func, "__name__", name)
+                    object.__setattr__(func, "__qualname__", name)
+                    state_cls._add_event_handler(name, func)
+                    return getattr(state_cls, name)
+
             return func  # pyright: ignore [reportReturnType]
 
         if func is not None:

+ 16 - 0
reflex/state.py

@@ -627,6 +627,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         all_base_state_classes[cls.get_full_name()] = None
 
+    @classmethod
+    def _add_event_handler(
+        cls,
+        name: str,
+        fn: Callable,
+    ):
+        """Add an event handler dynamically to the state.
+
+        Args:
+            name: The name of the event handler.
+            fn: The function to call when the event is triggered.
+        """
+        handler = cls._create_event_handler(fn)
+        cls.event_handlers[name] = handler
+        setattr(cls, name, handler)
+
     @staticmethod
     def _copy_fn(fn: Callable) -> Callable:
         """Copy a function. Used to copy ComputedVars and EventHandlers from mixins.

+ 44 - 0
tests/units/test_event.py

@@ -483,3 +483,47 @@ def test_event_bound_method() -> None:
 
     w = Wrapper()
     _ = rx.input(on_change=w.get_handler)
+
+
+def test_decentralized_event_with_args():
+    """Test the decentralized event."""
+
+    class S(BaseState):
+        field: Field[str] = field("")
+
+    @event
+    def e(s: S, arg: str):
+        s.field = arg
+
+    _ = rx.input(on_change=e("foo"))
+
+
+def test_decentralized_event_no_args():
+    """Test the decentralized event with no args."""
+
+    class S(BaseState):
+        field: Field[str] = field("")
+
+    @event
+    def e(s: S):
+        s.field = "foo"
+
+    _ = rx.input(on_change=e())
+    _ = rx.input(on_change=e)
+
+
+class GlobalState(BaseState):
+    """Global state for testing decentralized events."""
+
+    field: Field[str] = field("")
+
+
+@event
+def f(s: GlobalState, arg: str):
+    s.field = arg
+
+
+def test_decentralized_event_global_state():
+    """Test the decentralized event with a global state."""
+    _ = rx.input(on_change=f("foo"))
+    _ = rx.input(on_change=f)