Browse Source

[REF-2273] Implement .setvar special EventHandler (#3163)

* Allow EventHandler args to be partially applied

When an EventHandler is called with an incomplete set of args it creates a
partial EventSpec. This change allows Component._create_event_chain to apply
remaining args from an args_spec to an existing EventSpec to make it
functional.

Instead of requiring the use of `lambda` functions to pass arguments to an
EventHandler, they can now be passed directly and any remaining args defined in
the event trigger will be applied after those.

* [REF-2273] Implement `.setvar` special EventHandler

All State subclasses will now have a special `setvar` EventHandler which
appears in the autocomplete drop down, passes static analysis, and canbe used
to set State Vars in response to event triggers.

Before:
    rx.input(value=State.a, on_change=State.set_a)

After:
    rx.input(value=State.a, on_change=State.setvar("a"))

This reduces the "magic" because `setvar` is statically defined on all State
subclasses.

* Catch invalid Var names and types at compile time

* Add test cases for State.setvar

* Use a proper redis-compatible token
Masen Furer 1 year ago
parent
commit
c636c91c9c
5 changed files with 186 additions and 31 deletions
  1. 1 4
      reflex/components/component.py
  2. 43 26
      reflex/event.py
  3. 68 0
      reflex/state.py
  4. 36 1
      tests/test_event.py
  5. 38 0
      tests/test_state.py

+ 1 - 4
reflex/components/component.py

@@ -506,7 +506,7 @@ class Component(BaseComponent, ABC):
         if isinstance(value, List):
         if isinstance(value, List):
             events: list[EventSpec] = []
             events: list[EventSpec] = []
             for v in value:
             for v in value:
-                if isinstance(v, EventHandler):
+                if isinstance(v, (EventHandler, EventSpec)):
                     # Call the event handler to get the event.
                     # Call the event handler to get the event.
                     try:
                     try:
                         event = call_event_handler(v, args_spec)
                         event = call_event_handler(v, args_spec)
@@ -517,9 +517,6 @@ class Component(BaseComponent, ABC):
 
 
                     # Add the event to the chain.
                     # Add the event to the chain.
                     events.append(event)
                     events.append(event)
-                elif isinstance(v, EventSpec):
-                    # Add the event to the chain.
-                    events.append(v)
                 elif isinstance(v, Callable):
                 elif isinstance(v, Callable):
                     # Call the lambda to get the event chain.
                     # Call the lambda to get the event chain.
                     events.extend(call_event_fn(v, args_spec))
                     events.extend(call_event_fn(v, args_spec))

+ 43 - 26
reflex/event.py

@@ -18,7 +18,7 @@ from typing import (
 
 
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
-from reflex.utils import console, format
+from reflex.utils import format
 from reflex.utils.types import ArgsSpec
 from reflex.utils.types import ArgsSpec
 from reflex.vars import BaseVar, Var
 from reflex.vars import BaseVar, Var
 
 
@@ -168,7 +168,7 @@ class EventHandler(EventActionsMixin):
         """
         """
         return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
         return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
 
 
-    def __call__(self, *args: Var) -> EventSpec:
+    def __call__(self, *args: Any) -> EventSpec:
         """Pass arguments to the handler to get an event spec.
         """Pass arguments to the handler to get an event spec.
 
 
         This method configures event handlers that take in arguments.
         This method configures event handlers that take in arguments.
@@ -246,6 +246,34 @@ class EventSpec(EventActionsMixin):
             event_actions=self.event_actions.copy(),
             event_actions=self.event_actions.copy(),
         )
         )
 
 
+    def add_args(self, *args: Var) -> EventSpec:
+        """Add arguments to the event spec.
+
+        Args:
+            *args: The arguments to add positionally.
+
+        Returns:
+            The event spec with the new arguments.
+
+        Raises:
+            TypeError: If the arguments are invalid.
+        """
+        # Get the remaining unfilled function args.
+        fn_args = inspect.getfullargspec(self.handler.fn).args[1 + len(self.args) :]
+        fn_args = (Var.create_safe(arg) for arg in fn_args)
+
+        # Construct the payload.
+        values = []
+        for arg in args:
+            try:
+                values.append(Var.create(arg, _var_is_string=isinstance(arg, str)))
+            except TypeError as e:
+                raise TypeError(
+                    f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}."
+                ) from e
+        new_payload = tuple(zip(fn_args, values))
+        return self.with_args(self.args + new_payload)
+
 
 
 class CallableEventSpec(EventSpec):
 class CallableEventSpec(EventSpec):
     """Decorate an EventSpec-returning function to act as both a EventSpec and a function.
     """Decorate an EventSpec-returning function to act as both a EventSpec and a function.
@@ -732,7 +760,8 @@ def get_hydrate_event(state) -> str:
 
 
 
 
 def call_event_handler(
 def call_event_handler(
-    event_handler: EventHandler, arg_spec: Union[Var, ArgsSpec]
+    event_handler: EventHandler | EventSpec,
+    arg_spec: ArgsSpec,
 ) -> EventSpec:
 ) -> EventSpec:
     """Call an event handler to get the event spec.
     """Call an event handler to get the event spec.
 
 
@@ -750,33 +779,21 @@ def call_event_handler(
     Returns:
     Returns:
         The event spec from calling the event handler.
         The event spec from calling the event handler.
     """
     """
-    args = inspect.getfullargspec(event_handler.fn).args
+    parsed_args = parse_args_spec(arg_spec)  # type: ignore
 
 
-    # handle new API using lambda to define triggers
-    if isinstance(arg_spec, ArgsSpec):
-        parsed_args = parse_args_spec(arg_spec)  # type: ignore
+    if isinstance(event_handler, EventSpec):
+        # Handle partial application of EventSpec args
+        return event_handler.add_args(*parsed_args)
 
 
-        if len(args) == len(["self", *parsed_args]):
-            return event_handler(*parsed_args)  # type: ignore
-        else:
-            source = inspect.getsource(arg_spec)  # type: ignore
-            raise ValueError(
-                f"number of arguments in {event_handler.fn.__qualname__} "
-                f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
-            )
+    args = inspect.getfullargspec(event_handler.fn).args
+    if len(args) == len(["self", *parsed_args]):
+        return event_handler(*parsed_args)  # type: ignore
     else:
     else:
-        console.deprecate(
-            feature_name="EVENT_ARG API for triggers",
-            reason="Replaced by new API using lambda allow arbitrary number of args",
-            deprecation_version="0.2.8",
-            removal_version="0.5.0",
+        source = inspect.getsource(arg_spec)  # type: ignore
+        raise ValueError(
+            f"number of arguments in {event_handler.fn.__qualname__} "
+            f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
         )
         )
-        if len(args) == 1:
-            return event_handler()
-        assert (
-            len(args) == 2
-        ), f"Event handler {event_handler.fn} must have 1 or 2 arguments."
-        return event_handler(arg_spec)  # type: ignore
 
 
 
 
 def parse_args_spec(arg_spec: ArgsSpec):
 def parse_args_spec(arg_spec: ArgsSpec):

+ 68 - 0
reflex/state.py

@@ -247,6 +247,60 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]:
     return token, state_name
     return token, state_name
 
 
 
 
+class EventHandlerSetVar(EventHandler):
+    """A special event handler to wrap setvar functionality."""
+
+    state_cls: Type[BaseState]
+
+    def __init__(self, state_cls: Type[BaseState]):
+        """Initialize the EventHandlerSetVar.
+
+        Args:
+            state_cls: The state class that vars will be set on.
+        """
+        super().__init__(
+            fn=type(self).setvar,
+            state_full_name=state_cls.get_full_name(),
+            state_cls=state_cls,  # type: ignore
+        )
+
+    def setvar(self, var_name: str, value: Any):
+        """Set the state variable to the value of the event.
+
+        Note: `self` here will be an instance of the state, not EventHandlerSetVar.
+
+        Args:
+            var_name: The name of the variable to set.
+            value: The value to set the variable to.
+        """
+        getattr(self, constants.SETTER_PREFIX + var_name)(value)
+
+    def __call__(self, *args: Any) -> EventSpec:
+        """Performs pre-checks and munging on the provided args that will become an EventSpec.
+
+        Args:
+            *args: The event args.
+
+        Returns:
+            The (partial) EventSpec that will be used to create the event to setvar.
+
+        Raises:
+            AttributeError: If the given Var name does not exist on the state.
+            ValueError: If the given Var name is not a str
+        """
+        if args:
+            if not isinstance(args[0], str):
+                raise ValueError(
+                    f"Var name must be passed as a string, got {args[0]!r}"
+                )
+            # Check that the requested Var setter exists on the State at compile time.
+            if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None:
+                raise AttributeError(
+                    f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`"
+                )
+        return super().__call__(*args)
+
+
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
     """The state of the app."""
 
 
@@ -310,6 +364,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     # Whether the state has ever been touched since instantiation.
     # Whether the state has ever been touched since instantiation.
     _was_touched: bool = False
     _was_touched: bool = False
 
 
+    # A special event handler for setting base vars.
+    setvar: ClassVar[EventHandler]
+
     def __init__(
     def __init__(
         self,
         self,
         *args,
         *args,
@@ -500,6 +557,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 value.__qualname__ = f"{cls.__name__}.{name}"
                 value.__qualname__ = f"{cls.__name__}.{name}"
                 events[name] = value
                 events[name] = value
 
 
+        # Create the setvar event handler for this state
+        cls._create_setvar()
+
         for name, fn in events.items():
         for name, fn in events.items():
             handler = cls._create_event_handler(fn)
             handler = cls._create_event_handler(fn)
             cls.event_handlers[name] = handler
             cls.event_handlers[name] = handler
@@ -833,6 +893,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         """
         return EventHandler(fn=fn, state_full_name=cls.get_full_name())
         return EventHandler(fn=fn, state_full_name=cls.get_full_name())
 
 
+    @classmethod
+    def _create_setvar(cls):
+        """Create the setvar method for the state."""
+        cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls)
+
     @classmethod
     @classmethod
     def _create_setter(cls, prop: BaseVar):
     def _create_setter(cls, prop: BaseVar):
         """Create a setter for the var.
         """Create a setter for the var.
@@ -1800,6 +1865,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return state
         return state
 
 
 
 
+EventHandlerSetVar.update_forward_refs()
+
+
 class State(BaseState):
 class State(BaseState):
     """The app Base State."""
     """The app Base State."""
 
 

+ 36 - 1
tests/test_event.py

@@ -1,9 +1,10 @@
 import json
 import json
+from typing import List
 
 
 import pytest
 import pytest
 
 
 from reflex import event
 from reflex import event
-from reflex.event import Event, EventHandler, EventSpec, fix_events
+from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events
 from reflex.state import BaseState
 from reflex.state import BaseState
 from reflex.utils import format
 from reflex.utils import format
 from reflex.vars import Var
 from reflex.vars import Var
@@ -91,6 +92,40 @@ def test_call_event_handler():
         handler(test_fn)  # type: ignore
         handler(test_fn)  # type: ignore
 
 
 
 
+def test_call_event_handler_partial():
+    """Calling an EventHandler with incomplete args returns an EventSpec that can be extended."""
+
+    def test_fn_with_args(_, arg1, arg2):
+        pass
+
+    test_fn_with_args.__qualname__ = "test_fn_with_args"
+
+    def spec(a2: str) -> List[str]:
+        return [a2]
+
+    handler = EventHandler(fn=test_fn_with_args)
+    event_spec = handler(make_var("first"))
+    event_spec2 = call_event_handler(event_spec, spec)
+
+    assert event_spec.handler == handler
+    assert len(event_spec.args) == 1
+    assert event_spec.args[0][0].equals(Var.create_safe("arg1"))
+    assert event_spec.args[0][1].equals(Var.create_safe("first"))
+    assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})'
+
+    assert event_spec2 is not event_spec
+    assert event_spec2.handler == handler
+    assert len(event_spec2.args) == 2
+    assert event_spec2.args[0][0].equals(Var.create_safe("arg1"))
+    assert event_spec2.args[0][1].equals(Var.create_safe("first"))
+    assert event_spec2.args[1][0].equals(Var.create_safe("arg2"))
+    assert event_spec2.args[1][1].equals(Var.create_safe("_a2"))
+    assert (
+        format.format_event(event_spec2)
+        == 'Event("test_fn_with_args", {arg1:first,arg2:_a2})'
+    )
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     ("arg1", "arg2"),
     ("arg1", "arg2"),
     (
     (

+ 38 - 0
tests/test_state.py

@@ -2845,3 +2845,41 @@ def test_potentially_dirty_substates():
     assert RxState._potentially_dirty_substates() == {State}
     assert RxState._potentially_dirty_substates() == {State}
     assert State._potentially_dirty_substates() == {C1}
     assert State._potentially_dirty_substates() == {C1}
     assert C1._potentially_dirty_substates() == set()
     assert C1._potentially_dirty_substates() == set()
+
+
+@pytest.mark.asyncio
+async def test_setvar(mock_app: rx.App, token: str):
+    """Test that setvar works correctly.
+
+    Args:
+        mock_app: An app that will be returned by `get_app()`
+        token: A token.
+    """
+    state = await mock_app.state_manager.get_state(_substate_key(token, TestState))
+
+    # Set Var in same state (with Var type casting)
+    for event in rx.event.fix_events(
+        [TestState.setvar("num1", 42), TestState.setvar("num2", "4.2")], token
+    ):
+        async for update in state._process(event):
+            print(update)
+    assert state.num1 == 42
+    assert state.num2 == 4.2
+
+    # Set Var in parent state
+    for event in rx.event.fix_events([GrandchildState.setvar("array", [43])], token):
+        async for update in state._process(event):
+            print(update)
+    assert state.array == [43]
+
+    # Cannot setvar for non-existant var
+    with pytest.raises(AttributeError):
+        TestState.setvar("non_existant_var")
+
+    # Cannot setvar for computed vars
+    with pytest.raises(AttributeError):
+        TestState.setvar("sum")
+
+    # Cannot setvar with non-string
+    with pytest.raises(ValueError):
+        TestState.setvar(42, 42)