소스 검색

Expose DOM event actions on EventHandler, EventSpec, and EventChain (stopPropagation) (#1891)

* Expose preventDefault and stopPropagation for DOM events

All EventHandler, EventSpec, and EventChain can now carry these extra
"event_actions" that will be applied inside the frontend code when an event is
triggered from the DOM.

Fix #1621
Fix REF-675

* Test cases (and fixes) for "event_actions"

* form: from __future__ import annotations

for py38, py39 compat

* Revert overzealous merge conflict resolution
Masen Furer 1 년 전
부모
커밋
56476d0a86

+ 234 - 0
integration/test_event_actions.py

@@ -0,0 +1,234 @@
+"""Ensure stopPropagation and preventDefault work as expected."""
+
+from typing import Callable, Coroutine, Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import AppHarness, WebDriver
+
+
+def TestEventAction():
+    """App for testing event_actions."""
+    import reflex as rx
+
+    class EventActionState(rx.State):
+        order: list[str]
+
+        def on_click(self, ev):
+            self.order.append(f"on_click:{ev}")
+
+        def on_click2(self):
+            self.order.append("on_click2")
+
+        @rx.var
+        def token(self) -> str:
+            return self.get_token()
+
+    def index():
+        return rx.vstack(
+            rx.input(value=EventActionState.token, is_read_only=True, id="token"),
+            rx.button("No events", id="btn-no-events"),
+            rx.button(
+                "Stop Prop Only",
+                id="btn-stop-prop-only",
+                on_click=rx.stop_propagation,  # type: ignore
+            ),
+            rx.button(
+                "Click event",
+                on_click=EventActionState.on_click("no_event_actions"),  # type: ignore
+                id="btn-click-event",
+            ),
+            rx.button(
+                "Click stop propagation",
+                on_click=EventActionState.on_click("stop_propagation").stop_propagation,  # type: ignore
+                id="btn-click-stop-propagation",
+            ),
+            rx.button(
+                "Click stop propagation2",
+                on_click=EventActionState.on_click2.stop_propagation,
+                id="btn-click-stop-propagation2",
+            ),
+            rx.button(
+                "Click event 2",
+                on_click=EventActionState.on_click2,
+                id="btn-click-event2",
+            ),
+            rx.link(
+                "Link",
+                href="#",
+                on_click=EventActionState.on_click("link_no_event_actions"),  # type: ignore
+                id="link",
+            ),
+            rx.link(
+                "Link Stop Propagation",
+                href="#",
+                on_click=EventActionState.on_click(  # type: ignore
+                    "link_stop_propagation"
+                ).stop_propagation,
+                id="link-stop-propagation",
+            ),
+            rx.link(
+                "Link Prevent Default Only",
+                href="/invalid",
+                on_click=rx.prevent_default,  # type: ignore
+                id="link-prevent-default-only",
+            ),
+            rx.link(
+                "Link Prevent Default",
+                href="/invalid",
+                on_click=EventActionState.on_click(  # type: ignore
+                    "link_prevent_default"
+                ).prevent_default,
+                id="link-prevent-default",
+            ),
+            rx.link(
+                "Link Both",
+                href="/invalid",
+                on_click=EventActionState.on_click(  # type: ignore
+                    "link_both"
+                ).stop_propagation.prevent_default,
+                id="link-stop-propagation-prevent-default",
+            ),
+            rx.list(
+                rx.foreach(
+                    EventActionState.order,  # type: ignore
+                    rx.list_item,
+                ),
+            ),
+            on_click=EventActionState.on_click("outer"),  # type: ignore
+        )
+
+    app = rx.App(state=EventActionState)
+    app.add_page(index)
+    app.compile()
+
+
+@pytest.fixture(scope="session")
+def event_action(tmp_path_factory) -> Generator[AppHarness, None, None]:
+    """Start TestEventAction app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path_factory.mktemp(f"event_action"),
+        app_source=TestEventAction,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.fixture
+def driver(event_action: AppHarness) -> Generator[WebDriver, None, None]:
+    """Get an instance of the browser open to the event_action app.
+
+    Args:
+        event_action: harness for TestEventAction app
+
+    Yields:
+        WebDriver instance.
+    """
+    assert event_action.app_instance is not None, "app is not running"
+    driver = event_action.frontend()
+    try:
+        yield driver
+    finally:
+        driver.quit()
+
+
+@pytest.fixture()
+def token(event_action: AppHarness, driver: WebDriver) -> str:
+    """Get the token associated with backend state.
+
+    Args:
+        event_action: harness for TestEventAction app.
+        driver: WebDriver instance.
+
+    Returns:
+        The token visible in the driver browser.
+    """
+    assert event_action.app_instance is not None
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = event_action.poll_for_value(token_input)
+    assert token is not None
+
+    return token
+
+
+@pytest.fixture()
+def poll_for_order(
+    event_action: AppHarness, token: str
+) -> Callable[[list[str]], Coroutine[None, None, None]]:
+    """Poll for the order list to match the expected order.
+
+    Args:
+        event_action: harness for TestEventAction app.
+        token: The token visible in the driver browser.
+
+    Returns:
+        An async function that polls for the order list to match the expected order.
+    """
+
+    async def _poll_for_order(exp_order: list[str]):
+        async def _backend_state():
+            return await event_action.get_state(token)
+
+        async def _check():
+            return (await _backend_state()).order == exp_order
+
+        await AppHarness._poll_for_async(_check)
+        assert (await _backend_state()).order == exp_order
+
+    return _poll_for_order
+
+
+@pytest.mark.parametrize(
+    ("element_id", "exp_order"),
+    [
+        ("btn-no-events", ["on_click:outer"]),
+        ("btn-stop-prop-only", []),
+        ("btn-click-event", ["on_click:no_event_actions", "on_click:outer"]),
+        ("btn-click-stop-propagation", ["on_click:stop_propagation"]),
+        ("btn-click-stop-propagation2", ["on_click2"]),
+        ("btn-click-event2", ["on_click2", "on_click:outer"]),
+        ("link", ["on_click:link_no_event_actions", "on_click:outer"]),
+        ("link-stop-propagation", ["on_click:link_stop_propagation"]),
+        ("link-prevent-default", ["on_click:link_prevent_default", "on_click:outer"]),
+        ("link-prevent-default-only", ["on_click:outer"]),
+        ("link-stop-propagation-prevent-default", ["on_click:link_both"]),
+    ],
+)
+@pytest.mark.usefixtures("token")
+@pytest.mark.asyncio
+async def test_event_actions(
+    driver: WebDriver,
+    poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
+    element_id: str,
+    exp_order: list[str],
+):
+    """Click links and buttons and assert on fired events.
+
+    Args:
+        driver: WebDriver instance.
+        poll_for_order: function that polls for the order list to match the expected order.
+        element_id: The id of the element to click.
+        exp_order: The expected order of events.
+    """
+    el = driver.find_element(By.ID, element_id)
+    assert el
+
+    prev_url = driver.current_url
+
+    el.click()
+    await poll_for_order(exp_order)
+
+    if element_id.startswith("link") and "prevent-default" not in element_id:
+        assert driver.current_url != prev_url
+    else:
+        assert driver.current_url == prev_url

+ 6 - 0
integration/test_form_submit.py

@@ -53,6 +53,7 @@ def FormSubmit():
                     rx.button("Submit", type_="submit"),
                 ),
                 on_submit=FormState.form_submit,
+                custom_attrs={"action": "/invalid"},
             ),
             rx.spacer(),
             height="100vh",
@@ -145,6 +146,8 @@ async def test_submit(driver, form_submit: AppHarness):
 
     time.sleep(1)
 
+    prev_url = driver.current_url
+
     submit_input = driver.find_element(By.CLASS_NAME, "chakra-button")
     submit_input.click()
 
@@ -166,3 +169,6 @@ async def test_submit(driver, form_submit: AppHarness):
     assert form_data["select_input"] == "option1"
     assert form_data["text_area_input"] == "Some\nText"
     assert form_data["debounce_input"] == "bar baz"
+
+    # submitting the form should NOT change the url (preventDefault on_submit event)
+    assert driver.current_url == prev_url

+ 7 - 12
reflex/.templates/web/utils/state.js

@@ -486,8 +486,13 @@ export const useEventLoop = (
   const [connectError, setConnectError] = useState(null)
 
   // Function to add new events to the event queue.
-  const addEvents = (events, _e) => {
-    preventDefault(_e);
+  const addEvents = (events, _e, event_actions) => {
+    if (event_actions?.preventDefault && _e) {
+      _e.preventDefault();
+    }
+    if (event_actions?.stopPropagation && _e) {
+      _e.stopPropagation();
+    }
     queueEvents(events, socket)
   }
 
@@ -532,16 +537,6 @@ export const isTrue = (val) => {
   return Array.isArray(val) ? val.length > 0 : !!val;
 };
 
-/**
- * Prevent the default event for form submission.
- * @param event
- */
-export const preventDefault = (event) => {
-  if (event && event.type == "submit") {
-    event.preventDefault();
-  }
-};
-
 /**
  * Get the value from a ref.
  * @param ref The ref to get the value from.

+ 2 - 0
reflex/__init__.py

@@ -24,12 +24,14 @@ from .event import call_script as call_script
 from .event import clear_local_storage as clear_local_storage
 from .event import console_log as console_log
 from .event import download as download
+from .event import prevent_default as prevent_default
 from .event import redirect as redirect
 from .event import remove_cookie as remove_cookie
 from .event import remove_local_storage as remove_local_storage
 from .event import set_clipboard as set_clipboard
 from .event import set_focus as set_focus
 from .event import set_value as set_value
+from .event import stop_propagation as stop_propagation
 from .event import window_alert as window_alert
 from .middleware import Middleware as Middleware
 from .model import Model as Model

+ 20 - 11
reflex/components/component.py

@@ -242,6 +242,9 @@ class Component(Base, ABC):
             if value._var_type is not EventChain:
                 raise ValueError(f"Invalid event chain: {value}")
             return value
+        elif isinstance(value, EventChain):
+            # Trust that the caller knows what they're doing passing an EventChain directly
+            return value
 
         arg_spec = triggers.get(event_trigger, lambda: [])
 
@@ -260,7 +263,7 @@ class Component(Base, ABC):
                     deprecation_version="0.2.8",
                     removal_version="0.3.0",
                 )
-            events = []
+            events: list[EventSpec] = []
             for v in value:
                 if isinstance(v, EventHandler):
                     # Call the event handler to get the event.
@@ -291,20 +294,26 @@ class Component(Base, ABC):
             raise ValueError(f"Invalid event chain: {value}")
 
         # Add args to the event specs if necessary.
-        events = [
-            EventSpec(
-                handler=e.handler,
-                args=get_handler_args(e),
-                client_handler_name=e.client_handler_name,
-            )
-            for e in events
-        ]
+        events = [e.with_args(get_handler_args(e)) for e in events]
+
+        # Collect event_actions from each spec
+        event_actions = {}
+        for e in events:
+            event_actions.update(e.event_actions)
 
         # Return the event chain.
         if isinstance(arg_spec, Var):
-            return EventChain(events=events, args_spec=None)
+            return EventChain(
+                events=events,
+                args_spec=None,
+                event_actions=event_actions,
+            )
         else:
-            return EventChain(events=events, args_spec=arg_spec)  # type: ignore
+            return EventChain(
+                events=events,
+                args_spec=arg_spec,  # type: ignore
+                event_actions=event_actions,
+            )
 
     def get_event_triggers(self) -> Dict[str, Any]:
         """Get the event triggers for the component.

+ 26 - 1
reflex/components/forms/form.py

@@ -1,10 +1,12 @@
 """Form components."""
+from __future__ import annotations
 
-from typing import Any, Dict
+from typing import Any, Callable, Dict, List
 
 from reflex.components.component import Component
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.constants import EventTriggers
+from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.vars import Var
 
 
@@ -16,6 +18,29 @@ class Form(ChakraComponent):
     # What the form renders to.
     as_: Var[str] = "form"  # type: ignore
 
+    def _create_event_chain(
+        self,
+        event_trigger: str,
+        value: Var
+        | EventHandler
+        | EventSpec
+        | List[EventHandler | EventSpec]
+        | Callable[..., Any],
+    ) -> EventChain | Var:
+        """Override the event chain creation to preventDefault for on_submit.
+
+        Args:
+            event_trigger: The event trigger.
+            value: The value of the event trigger.
+
+        Returns:
+            The event chain.
+        """
+        chain = super()._create_event_chain(event_trigger, value)
+        if event_trigger == EventTriggers.ON_SUBMIT and isinstance(chain, EventChain):
+            return chain.prevent_default
+        return chain
+
     def get_event_triggers(self) -> Dict[str, Any]:
         """Get the event triggers that pass the component's value to the handler.
 

+ 57 - 4
reflex/event.py

@@ -101,7 +101,36 @@ def _no_chain_background_task(
     raise TypeError(f"{fn} is marked as a background task, but is not async.")
 
 
-class EventHandler(Base):
+class EventActionsMixin(Base):
+    """Mixin for DOM event actions."""
+
+    # Whether to `preventDefault` or `stopPropagation` on the event.
+    event_actions: Dict[str, bool] = {}
+
+    @property
+    def stop_propagation(self):
+        """Stop the event from bubbling up the DOM tree.
+
+        Returns:
+            New EventHandler-like with stopPropagation set to True.
+        """
+        return self.copy(
+            update={"event_actions": {"stopPropagation": True, **self.event_actions}},
+        )
+
+    @property
+    def prevent_default(self):
+        """Prevent the default behavior of the event.
+
+        Returns:
+            New EventHandler-like with preventDefault set to True.
+        """
+        return self.copy(
+            update={"event_actions": {"preventDefault": True, **self.event_actions}},
+        )
+
+
+class EventHandler(EventActionsMixin):
     """An event handler responds to an event to update the state."""
 
     # The function to call in response to the event.
@@ -150,6 +179,7 @@ class EventHandler(Base):
                     client_handler_name="uploadFiles",
                     # `files` is defined in the Upload component's _use_hooks
                     args=((Var.create_safe("files"), Var.create_safe("files")),),
+                    event_actions=self.event_actions.copy(),
                 )
 
             # Otherwise, convert to JSON.
@@ -162,10 +192,12 @@ class EventHandler(Base):
         payload = tuple(zip(fn_args, values))
 
         # Return the event spec.
-        return EventSpec(handler=self, args=payload)
+        return EventSpec(
+            handler=self, args=payload, event_actions=self.event_actions.copy()
+        )
 
 
-class EventSpec(Base):
+class EventSpec(EventActionsMixin):
     """An event specification.
 
     Whereas an Event object is passed during runtime, a spec is used
@@ -187,8 +219,24 @@ class EventSpec(Base):
         # Required to allow tuple fields.
         frozen = True
 
+    def with_args(self, args: Tuple[Tuple[Var, Var], ...]) -> EventSpec:
+        """Copy the event spec, with updated args.
+
+        Args:
+            args: The new args to pass to the function.
+
+        Returns:
+            A copy of the event spec, with the new args.
+        """
+        return type(self)(
+            handler=self.handler,
+            client_handler_name=self.client_handler_name,
+            args=args,
+            event_actions=self.event_actions.copy(),
+        )
 
-class EventChain(Base):
+
+class EventChain(EventActionsMixin):
     """Container for a chain of events that will be executed in order."""
 
     events: List[EventSpec]
@@ -196,6 +244,11 @@ class EventChain(Base):
     args_spec: Optional[Callable]
 
 
+# These chains can be used for their side effects when no other events are desired.
+stop_propagation = EventChain(events=[], args_spec=lambda: []).stop_propagation
+prevent_default = EventChain(events=[], args_spec=lambda: []).prevent_default
+
+
 class Target(Base):
     """A Javascript event target."""
 

+ 1 - 1
reflex/utils/format.py

@@ -314,7 +314,7 @@ def format_prop(
                 arg_def = "(_e)"
 
             chain = ",".join([format_event(event) for event in prop.events])
-            event = f"addEvents([{chain}], {arg_def})"
+            event = f"addEvents([{chain}], {arg_def}, {json_dumps(prop.event_actions)})"
             prop = f"{arg_def} => {event}"
 
         # Handle other types.

+ 3 - 3
tests/components/base/test_script.py

@@ -57,14 +57,14 @@ def test_script_event_handler():
     )
     render_dict = component.render()
     assert (
-        'onReady={(_e) => addEvents([Event("ev_state.on_ready", {})], (_e))}'
+        'onReady={(_e) => addEvents([Event("ev_state.on_ready", {})], (_e), {})}'
         in render_dict["props"]
     )
     assert (
-        'onLoad={(_e) => addEvents([Event("ev_state.on_load", {})], (_e))}'
+        'onLoad={(_e) => addEvents([Event("ev_state.on_load", {})], (_e), {})}'
         in render_dict["props"]
     )
     assert (
-        'onError={(_e) => addEvents([Event("ev_state.on_error", {})], (_e))}'
+        'onError={(_e) => addEvents([Event("ev_state.on_error", {})], (_e), {})}'
         in render_dict["props"]
     )

+ 1 - 1
tests/components/test_component.py

@@ -425,7 +425,7 @@ def test_component_event_trigger_arbitrary_args():
     assert comp.render()["props"][0] == (
         "onFoo={(__e,_alpha,_bravo,_charlie) => addEvents("
         '[Event("c1_state.mock_handler", {_e:__e.target.value,_bravo:_bravo["nested"],_charlie:(_charlie.custom + 42)})], '
-        "(__e,_alpha,_bravo,_charlie))}"
+        "(__e,_alpha,_bravo,_charlie), {})}"
     )
 
 

+ 52 - 0
tests/test_event.py

@@ -4,6 +4,7 @@ import pytest
 
 from reflex import event
 from reflex.event import Event, EventHandler, EventSpec, fix_events
+from reflex.state import State
 from reflex.utils import format
 from reflex.vars import Var
 
@@ -261,3 +262,54 @@ def test_remove_local_storage():
     assert (
         format.format_event(spec) == 'Event("_remove_local_storage", {key:`testkey`})'
     )
+
+
+def test_event_actions():
+    """Test DOM event actions, like stopPropagation and preventDefault."""
+    # EventHandler
+    handler = EventHandler(fn=lambda: None)
+    assert not handler.event_actions
+    sp_handler = handler.stop_propagation
+    assert handler is not sp_handler
+    assert sp_handler.event_actions == {"stopPropagation": True}
+    pd_handler = handler.prevent_default
+    assert handler is not pd_handler
+    assert pd_handler.event_actions == {"preventDefault": True}
+    both_handler = sp_handler.prevent_default
+    assert both_handler is not sp_handler
+    assert both_handler.event_actions == {
+        "stopPropagation": True,
+        "preventDefault": True,
+    }
+    assert not handler.event_actions
+
+    # Convert to EventSpec should carry event actions
+    sp_handler2 = handler.stop_propagation
+    spec = sp_handler2()
+    assert spec.event_actions == {"stopPropagation": True}
+    assert spec.event_actions == sp_handler2.event_actions
+    assert spec.event_actions is not sp_handler2.event_actions
+    # But it should be a copy!
+    assert spec.event_actions is not sp_handler2.event_actions
+    spec2 = spec.prevent_default
+    assert spec is not spec2
+    assert spec2.event_actions == {"stopPropagation": True, "preventDefault": True}
+    assert spec2.event_actions != spec.event_actions
+
+    # The original handler should still not be touched.
+    assert not handler.event_actions
+
+
+def test_event_actions_on_state():
+    class EventActionState(State):
+        def handler(self):
+            pass
+
+    handler = EventActionState.handler
+    assert isinstance(handler, EventHandler)
+    assert not handler.event_actions
+
+    sp_handler = EventActionState.handler.stop_propagation
+    assert sp_handler.event_actions == {"stopPropagation": True}
+    # should NOT affect other references to the handler
+    assert not handler.event_actions

+ 44 - 1
tests/utils/test_format.py

@@ -4,7 +4,7 @@ from typing import Any
 import pytest
 
 from reflex.components.tags.tag import Tag
-from reflex.event import EventChain, EventHandler, EventSpec
+from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent
 from reflex.style import Style
 from reflex.utils import format
 from reflex.vars import BaseVar, Var
@@ -290,6 +290,49 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected
             },
             r'{{"a": "foo \"{ \"bar\" }\" baz", "b": val}}',
         ),
+        (
+            EventChain(
+                events=[EventSpec(handler=EventHandler(fn=mock_event))],
+                args_spec=lambda: [],
+            ),
+            '{(_e) => addEvents([Event("mock_event", {})], (_e), {})}',
+        ),
+        (
+            EventChain(
+                events=[
+                    EventSpec(
+                        handler=EventHandler(fn=mock_event),
+                        args=(
+                            (
+                                Var.create_safe("arg"),
+                                BaseVar(
+                                    _var_name="_e",
+                                    _var_type=FrontendEvent,
+                                ).target.value,
+                            ),
+                        ),
+                    )
+                ],
+                args_spec=lambda: [],
+            ),
+            '{(_e) => addEvents([Event("mock_event", {arg:_e.target.value})], (_e), {})}',
+        ),
+        (
+            EventChain(
+                events=[EventSpec(handler=EventHandler(fn=mock_event))],
+                args_spec=lambda: [],
+                event_actions={"stopPropagation": True},
+            ),
+            '{(_e) => addEvents([Event("mock_event", {})], (_e), {"stopPropagation": true})}',
+        ),
+        (
+            EventChain(
+                events=[EventSpec(handler=EventHandler(fn=mock_event))],
+                args_spec=lambda: [],
+                event_actions={"preventDefault": True},
+            ),
+            '{(_e) => addEvents([Event("mock_event", {})], (_e), {"preventDefault": true})}',
+        ),
         ({"a": "red", "b": "blue"}, '{{"a": "red", "b": "blue"}}'),
         (BaseVar(_var_name="var", _var_type="int"), "{var}"),
         (