浏览代码

[REF-2273] Implement .setvar special EventHandler (#3163)

* Allow EventHandler args to be partially applied

When an EventHandler is called with an incomplete set of args it creates a
partial EventSpec. This change allows Component._create_event_chain to apply
remaining args from an args_spec to an existing EventSpec to make it
functional.

Instead of requiring the use of `lambda` functions to pass arguments to an
EventHandler, they can now be passed directly and any remaining args defined in
the event trigger will be applied after those.

* [REF-2273] Implement `.setvar` special EventHandler

All State subclasses will now have a special `setvar` EventHandler which
appears in the autocomplete drop down, passes static analysis, and canbe used
to set State Vars in response to event triggers.

Before:
    rx.input(value=State.a, on_change=State.set_a)

After:
    rx.input(value=State.a, on_change=State.setvar("a"))

This reduces the "magic" because `setvar` is statically defined on all State
subclasses.

* Catch invalid Var names and types at compile time

* Add test cases for State.setvar

* Use a proper redis-compatible token
Masen Furer 1 年之前
父节点
当前提交
c636c91c9c
共有 5 个文件被更改,包括 186 次插入31 次删除
  1. 1 4
      reflex/components/component.py
  2. 43 26
      reflex/event.py
  3. 68 0
      reflex/state.py
  4. 36 1
      tests/test_event.py
  5. 38 0
      tests/test_state.py

+ 1 - 4
reflex/components/component.py

@@ -506,7 +506,7 @@ class Component(BaseComponent, ABC):
         if isinstance(value, List):
             events: list[EventSpec] = []
             for v in value:
-                if isinstance(v, EventHandler):
+                if isinstance(v, (EventHandler, EventSpec)):
                     # Call the event handler to get the event.
                     try:
                         event = call_event_handler(v, args_spec)
@@ -517,9 +517,6 @@ class Component(BaseComponent, ABC):
 
                     # Add the event to the chain.
                     events.append(event)
-                elif isinstance(v, EventSpec):
-                    # Add the event to the chain.
-                    events.append(v)
                 elif isinstance(v, Callable):
                     # Call the lambda to get the event chain.
                     events.extend(call_event_fn(v, args_spec))

+ 43 - 26
reflex/event.py

@@ -18,7 +18,7 @@ from typing import (
 
 from reflex import constants
 from reflex.base import Base
-from reflex.utils import console, format
+from reflex.utils import format
 from reflex.utils.types import ArgsSpec
 from reflex.vars import BaseVar, Var
 
@@ -168,7 +168,7 @@ class EventHandler(EventActionsMixin):
         """
         return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
 
-    def __call__(self, *args: Var) -> EventSpec:
+    def __call__(self, *args: Any) -> EventSpec:
         """Pass arguments to the handler to get an event spec.
 
         This method configures event handlers that take in arguments.
@@ -246,6 +246,34 @@ class EventSpec(EventActionsMixin):
             event_actions=self.event_actions.copy(),
         )
 
+    def add_args(self, *args: Var) -> EventSpec:
+        """Add arguments to the event spec.
+
+        Args:
+            *args: The arguments to add positionally.
+
+        Returns:
+            The event spec with the new arguments.
+
+        Raises:
+            TypeError: If the arguments are invalid.
+        """
+        # Get the remaining unfilled function args.
+        fn_args = inspect.getfullargspec(self.handler.fn).args[1 + len(self.args) :]
+        fn_args = (Var.create_safe(arg) for arg in fn_args)
+
+        # Construct the payload.
+        values = []
+        for arg in args:
+            try:
+                values.append(Var.create(arg, _var_is_string=isinstance(arg, str)))
+            except TypeError as e:
+                raise TypeError(
+                    f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}."
+                ) from e
+        new_payload = tuple(zip(fn_args, values))
+        return self.with_args(self.args + new_payload)
+
 
 class CallableEventSpec(EventSpec):
     """Decorate an EventSpec-returning function to act as both a EventSpec and a function.
@@ -732,7 +760,8 @@ def get_hydrate_event(state) -> str:
 
 
 def call_event_handler(
-    event_handler: EventHandler, arg_spec: Union[Var, ArgsSpec]
+    event_handler: EventHandler | EventSpec,
+    arg_spec: ArgsSpec,
 ) -> EventSpec:
     """Call an event handler to get the event spec.
 
@@ -750,33 +779,21 @@ def call_event_handler(
     Returns:
         The event spec from calling the event handler.
     """
-    args = inspect.getfullargspec(event_handler.fn).args
+    parsed_args = parse_args_spec(arg_spec)  # type: ignore
 
-    # handle new API using lambda to define triggers
-    if isinstance(arg_spec, ArgsSpec):
-        parsed_args = parse_args_spec(arg_spec)  # type: ignore
+    if isinstance(event_handler, EventSpec):
+        # Handle partial application of EventSpec args
+        return event_handler.add_args(*parsed_args)
 
-        if len(args) == len(["self", *parsed_args]):
-            return event_handler(*parsed_args)  # type: ignore
-        else:
-            source = inspect.getsource(arg_spec)  # type: ignore
-            raise ValueError(
-                f"number of arguments in {event_handler.fn.__qualname__} "
-                f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
-            )
+    args = inspect.getfullargspec(event_handler.fn).args
+    if len(args) == len(["self", *parsed_args]):
+        return event_handler(*parsed_args)  # type: ignore
     else:
-        console.deprecate(
-            feature_name="EVENT_ARG API for triggers",
-            reason="Replaced by new API using lambda allow arbitrary number of args",
-            deprecation_version="0.2.8",
-            removal_version="0.5.0",
+        source = inspect.getsource(arg_spec)  # type: ignore
+        raise ValueError(
+            f"number of arguments in {event_handler.fn.__qualname__} "
+            f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
         )
-        if len(args) == 1:
-            return event_handler()
-        assert (
-            len(args) == 2
-        ), f"Event handler {event_handler.fn} must have 1 or 2 arguments."
-        return event_handler(arg_spec)  # type: ignore
 
 
 def parse_args_spec(arg_spec: ArgsSpec):

+ 68 - 0
reflex/state.py

@@ -247,6 +247,60 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]:
     return token, state_name
 
 
+class EventHandlerSetVar(EventHandler):
+    """A special event handler to wrap setvar functionality."""
+
+    state_cls: Type[BaseState]
+
+    def __init__(self, state_cls: Type[BaseState]):
+        """Initialize the EventHandlerSetVar.
+
+        Args:
+            state_cls: The state class that vars will be set on.
+        """
+        super().__init__(
+            fn=type(self).setvar,
+            state_full_name=state_cls.get_full_name(),
+            state_cls=state_cls,  # type: ignore
+        )
+
+    def setvar(self, var_name: str, value: Any):
+        """Set the state variable to the value of the event.
+
+        Note: `self` here will be an instance of the state, not EventHandlerSetVar.
+
+        Args:
+            var_name: The name of the variable to set.
+            value: The value to set the variable to.
+        """
+        getattr(self, constants.SETTER_PREFIX + var_name)(value)
+
+    def __call__(self, *args: Any) -> EventSpec:
+        """Performs pre-checks and munging on the provided args that will become an EventSpec.
+
+        Args:
+            *args: The event args.
+
+        Returns:
+            The (partial) EventSpec that will be used to create the event to setvar.
+
+        Raises:
+            AttributeError: If the given Var name does not exist on the state.
+            ValueError: If the given Var name is not a str
+        """
+        if args:
+            if not isinstance(args[0], str):
+                raise ValueError(
+                    f"Var name must be passed as a string, got {args[0]!r}"
+                )
+            # Check that the requested Var setter exists on the State at compile time.
+            if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None:
+                raise AttributeError(
+                    f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`"
+                )
+        return super().__call__(*args)
+
+
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
 
@@ -310,6 +364,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # Whether the state has ever been touched since instantiation.
     _was_touched: bool = False
 
+    # A special event handler for setting base vars.
+    setvar: ClassVar[EventHandler]
+
     def __init__(
         self,
         *args,
@@ -500,6 +557,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 value.__qualname__ = f"{cls.__name__}.{name}"
                 events[name] = value
 
+        # Create the setvar event handler for this state
+        cls._create_setvar()
+
         for name, fn in events.items():
             handler = cls._create_event_handler(fn)
             cls.event_handlers[name] = handler
@@ -833,6 +893,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         return EventHandler(fn=fn, state_full_name=cls.get_full_name())
 
+    @classmethod
+    def _create_setvar(cls):
+        """Create the setvar method for the state."""
+        cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls)
+
     @classmethod
     def _create_setter(cls, prop: BaseVar):
         """Create a setter for the var.
@@ -1800,6 +1865,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return state
 
 
+EventHandlerSetVar.update_forward_refs()
+
+
 class State(BaseState):
     """The app Base State."""
 

+ 36 - 1
tests/test_event.py

@@ -1,9 +1,10 @@
 import json
+from typing import List
 
 import pytest
 
 from reflex import event
-from reflex.event import Event, EventHandler, EventSpec, fix_events
+from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events
 from reflex.state import BaseState
 from reflex.utils import format
 from reflex.vars import Var
@@ -91,6 +92,40 @@ def test_call_event_handler():
         handler(test_fn)  # type: ignore
 
 
+def test_call_event_handler_partial():
+    """Calling an EventHandler with incomplete args returns an EventSpec that can be extended."""
+
+    def test_fn_with_args(_, arg1, arg2):
+        pass
+
+    test_fn_with_args.__qualname__ = "test_fn_with_args"
+
+    def spec(a2: str) -> List[str]:
+        return [a2]
+
+    handler = EventHandler(fn=test_fn_with_args)
+    event_spec = handler(make_var("first"))
+    event_spec2 = call_event_handler(event_spec, spec)
+
+    assert event_spec.handler == handler
+    assert len(event_spec.args) == 1
+    assert event_spec.args[0][0].equals(Var.create_safe("arg1"))
+    assert event_spec.args[0][1].equals(Var.create_safe("first"))
+    assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})'
+
+    assert event_spec2 is not event_spec
+    assert event_spec2.handler == handler
+    assert len(event_spec2.args) == 2
+    assert event_spec2.args[0][0].equals(Var.create_safe("arg1"))
+    assert event_spec2.args[0][1].equals(Var.create_safe("first"))
+    assert event_spec2.args[1][0].equals(Var.create_safe("arg2"))
+    assert event_spec2.args[1][1].equals(Var.create_safe("_a2"))
+    assert (
+        format.format_event(event_spec2)
+        == 'Event("test_fn_with_args", {arg1:first,arg2:_a2})'
+    )
+
+
 @pytest.mark.parametrize(
     ("arg1", "arg2"),
     (

+ 38 - 0
tests/test_state.py

@@ -2845,3 +2845,41 @@ def test_potentially_dirty_substates():
     assert RxState._potentially_dirty_substates() == {State}
     assert State._potentially_dirty_substates() == {C1}
     assert C1._potentially_dirty_substates() == set()
+
+
+@pytest.mark.asyncio
+async def test_setvar(mock_app: rx.App, token: str):
+    """Test that setvar works correctly.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+    state = await mock_app.state_manager.get_state(_substate_key(token, TestState))
+
+    # Set Var in same state (with Var type casting)
+    for event in rx.event.fix_events(
+        [TestState.setvar("num1", 42), TestState.setvar("num2", "4.2")], token
+    ):
+        async for update in state._process(event):
+            print(update)
+    assert state.num1 == 42
+    assert state.num2 == 4.2
+
+    # Set Var in parent state
+    for event in rx.event.fix_events([GrandchildState.setvar("array", [43])], token):
+        async for update in state._process(event):
+            print(update)
+    assert state.array == [43]
+
+    # Cannot setvar for non-existant var
+    with pytest.raises(AttributeError):
+        TestState.setvar("non_existant_var")
+
+    # Cannot setvar for computed vars
+    with pytest.raises(AttributeError):
+        TestState.setvar("sum")
+
+    # Cannot setvar with non-string
+    with pytest.raises(ValueError):
+        TestState.setvar(42, 42)