Prechádzať zdrojové kódy

Implement `on_mount` and `on_unmount` for all components. (#1636)

Masen Furer 1 rok pred
rodič
commit
2392c52928

+ 93 - 0
integration/test_event_chain.py

@@ -155,8 +155,35 @@ def EventChain():
             rx.input(value=State.token, readonly=True, id="token"),
         )
 
+    def on_mount_return_chain():
+        return rx.fragment(
+            rx.text(
+                "return",
+                on_mount=State.on_load_return_chain,
+                on_unmount=lambda: State.event_arg("unmount"),  # type: ignore
+            ),
+            rx.input(value=State.token, readonly=True, id="token"),
+            rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
+        )
+
+    def on_mount_yield_chain():
+        return rx.fragment(
+            rx.text(
+                "yield",
+                on_mount=[
+                    State.on_load_yield_chain,
+                    lambda: State.event_arg("mount"),  # type: ignore
+                ],
+                on_unmount=State.event_no_args,
+            ),
+            rx.input(value=State.token, readonly=True, id="token"),
+            rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
+        )
+
     app.add_page(on_load_return_chain, on_load=State.on_load_return_chain)  # type: ignore
     app.add_page(on_load_yield_chain, on_load=State.on_load_yield_chain)  # type: ignore
+    app.add_page(on_mount_return_chain)
+    app.add_page(on_mount_yield_chain)
 
     app.compile()
 
@@ -330,3 +357,69 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
     time.sleep(0.5)
     backend_state = event_chain.app_instance.state_manager.states[token]
     assert backend_state.event_order == exp_event_order
+
+
+@pytest.mark.parametrize(
+    ("uri", "exp_event_order"),
+    [
+        (
+            "/on-mount-return-chain",
+            [
+                "on_load_return_chain",
+                "event_arg:unmount",
+                "on_load_return_chain",
+                "event_arg:1",
+                "event_arg:2",
+                "event_arg:3",
+                "event_arg:1",
+                "event_arg:2",
+                "event_arg:3",
+                "event_arg:unmount",
+            ],
+        ),
+        (
+            "/on-mount-yield-chain",
+            [
+                "on_load_yield_chain",
+                "event_arg:mount",
+                "event_no_args",
+                "on_load_yield_chain",
+                "event_arg:mount",
+                "event_arg:4",
+                "event_arg:5",
+                "event_arg:6",
+                "event_arg:4",
+                "event_arg:5",
+                "event_arg:6",
+                "event_no_args",
+            ],
+        ),
+    ],
+)
+def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
+    """Load the URI, assert that the events are handled in the correct order.
+
+    These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
+    due to react StrictMode being used.
+
+    In prod mode, these events are only fired once.
+
+    Args:
+        event_chain: AppHarness for the event_chain app
+        driver: selenium WebDriver open to the app
+        uri: the page to load
+        exp_event_order: the expected events recorded in the State
+    """
+    driver.get(event_chain.frontend_url + uri)
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    token = event_chain.poll_for_value(token_input)
+
+    unmount_button = driver.find_element(By.ID, "unmount")
+    assert unmount_button
+    unmount_button.click()
+
+    time.sleep(1)
+    backend_state = event_chain.app_instance.state_manager.states[token]
+    assert backend_state.event_order == exp_event_order

+ 5 - 0
reflex/.templates/web/utils/state.js

@@ -218,6 +218,11 @@ export const queueEvents = async (events, socket) => {
 export const processEvent = async (
   socket
 ) => {
+  // Only proceed if the socket is up, otherwise we throw the event into the void
+  if (!socket) {
+    return;
+  }
+
   // Only proceed if we're not already processing an event.
   if (event_queue.length === 0 || event_processing) {
     return;

+ 57 - 6
reflex/components/component.py

@@ -286,7 +286,11 @@ class Component(Base, ABC):
         Returns:
             The event triggers.
         """
-        return EVENT_TRIGGERS | set(self.get_controlled_triggers())
+        return (
+            EVENT_TRIGGERS
+            | set(self.get_controlled_triggers())
+            | set((constants.ON_MOUNT, constants.ON_UNMOUNT))
+        )
 
     def get_controlled_triggers(self) -> Dict[str, Var]:
         """Get the event triggers that pass the component's value to the handler.
@@ -525,16 +529,63 @@ class Component(Base, ABC):
             self._get_imports(), *[child.get_imports() for child in self.children]
         )
 
-    def _get_hooks(self) -> Optional[str]:
-        """Get the React hooks for this component.
+    def _get_mount_lifecycle_hook(self) -> str | None:
+        """Generate the component lifecycle hook.
 
         Returns:
-            The hooks for just this component.
+            The useEffect hook for managing `on_mount` and `on_unmount` events.
+        """
+        # pop on_mount and on_unmount from event_triggers since these are handled by
+        # hooks, not as actually props in the component
+        on_mount = self.event_triggers.pop(constants.ON_MOUNT, None)
+        on_unmount = self.event_triggers.pop(constants.ON_UNMOUNT, None)
+        if on_mount:
+            on_mount = format.format_event_chain(on_mount)
+        if on_unmount:
+            on_unmount = format.format_event_chain(on_unmount)
+        if on_mount or on_unmount:
+            return f"""
+                useEffect(() => {{
+                    {on_mount or ""}
+                    return () => {{
+                        {on_unmount or ""}
+                    }}
+                }}, []);"""
+
+    def _get_ref_hook(self) -> str | None:
+        """Generate the ref hook for the component.
+
+        Returns:
+            The useRef hook for managing refs.
         """
         ref = self.get_ref()
         if ref is not None:
             return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
-        return None
+
+    def _get_hooks_internal(self) -> Set[str]:
+        """Get the React hooks for this component managed by the framework.
+
+        Downstream components should NOT override this method to avoid breaking
+        framework functionality.
+
+        Returns:
+            Set of internally managed hooks.
+        """
+        return set(
+            hook
+            for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
+            if hook
+        )
+
+    def _get_hooks(self) -> Optional[str]:
+        """Get the React hooks for this component.
+
+        Downstream components should override this method to add their own hooks.
+
+        Returns:
+            The hooks for just this component.
+        """
+        return
 
     def get_hooks(self) -> Set[str]:
         """Get the React hooks for this component and its children.
@@ -543,7 +594,7 @@ class Component(Base, ABC):
             The code that should appear just before returning the rendered component.
         """
         # Store the code in a set to avoid duplicates.
-        code = set()
+        code = self._get_hooks_internal()
 
         # Add the hook code for this component.
         hooks = self._get_hooks()

+ 4 - 4
reflex/components/forms/pininput.py

@@ -76,8 +76,8 @@ class PinInput(ChakraComponent):
         """
         return None
 
-    def _get_hooks(self) -> Optional[str]:
-        """Override the base get_hooks to handle array refs.
+    def _get_ref_hook(self) -> Optional[str]:
+        """Override the base _get_ref_hook to handle array refs.
 
         Returns:
             The overrided hooks.
@@ -86,7 +86,7 @@ class PinInput(ChakraComponent):
             ref = format.format_array_ref(self.id, None)
             if ref:
                 return f"const {ref} = Array.from({{length:{self.length}}}, () => useRef(null));"
-            return super()._get_hooks()
+            return super()._get_ref_hook()
 
     @classmethod
     def create(cls, *children, **props) -> Component:
@@ -130,7 +130,7 @@ class PinInputField(ChakraComponent):
     # Default to None because it is assigned by PinInput when created.
     index: Optional[Var[int]] = None
 
-    def _get_hooks(self) -> Optional[str]:
+    def _get_ref_hook(self) -> Optional[str]:
         return None
 
     def get_ref(self):

+ 4 - 4
reflex/components/forms/rangeslider.py

@@ -64,8 +64,8 @@ class RangeSlider(ChakraComponent):
         """
         return None
 
-    def _get_hooks(self) -> Optional[str]:
-        """Override the base get_hooks to handle array refs.
+    def _get_ref_hook(self) -> Optional[str]:
+        """Override the base _get_ref_hook to handle array refs.
 
         Returns:
             The overrided hooks.
@@ -74,7 +74,7 @@ class RangeSlider(ChakraComponent):
             ref = format.format_array_ref(self.id, None)
             if ref:
                 return f"const {ref} = Array.from({{length:2}}, () => useRef(null));"
-            return super()._get_hooks()
+            return super()._get_ref_hook()
 
     @classmethod
     def create(cls, *children, **props) -> Component:
@@ -130,7 +130,7 @@ class RangeSliderThumb(ChakraComponent):
     # The position of the thumb.
     index: Var[int]
 
-    def _get_hooks(self) -> Optional[str]:
+    def _get_ref_hook(self) -> Optional[str]:
         # hook is None because RangeSlider is handling it.
         return None
 

+ 4 - 0
reflex/constants.py

@@ -359,5 +359,9 @@ PING_TIMEOUT = 120
 # Alembic migrations
 ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini")
 
+# Names of event handlers on all components mapped to useEffect
+ON_MOUNT = "on_mount"
+ON_UNMOUNT = "on_unmount"
+
 # If this env var is set to "yes", App.compile will be a no-op
 SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"

+ 60 - 1
reflex/utils/format.py

@@ -16,10 +16,11 @@ from plotly.io import to_json
 
 from reflex import constants
 from reflex.utils import types
+from reflex.vars import Var
 
 if TYPE_CHECKING:
     from reflex.components.component import ComponentStyle
-    from reflex.event import EventHandler, EventSpec
+    from reflex.event import EventChain, EventHandler, EventSpec
 
 WRAP_MAP = {
     "{": "}",
@@ -182,6 +183,24 @@ def format_string(string: str) -> str:
     return string
 
 
+def format_var(var: Var) -> str:
+    """Format the given Var as a javascript value.
+
+    Args:
+        var: The Var to format.
+
+    Returns:
+        The formatted Var.
+    """
+    if not var.is_local or var.is_string:
+        return str(var)
+    if types._issubclass(var.type_, str):
+        return format_string(var.full_name)
+    if is_wrapped(var.full_name, "{"):
+        return var.full_name
+    return json_dumps(var.full_name)
+
+
 def format_route(route: str) -> str:
     """Format the given route.
 
@@ -311,6 +330,46 @@ def format_event(event_spec: EventSpec) -> str:
     return f"E({', '.join(event_args)})"
 
 
+def format_event_chain(
+    event_chain: EventChain | Var[EventChain],
+    event_arg: Var | None = None,
+) -> str:
+    """Format an event chain as a javascript invocation.
+
+    Args:
+        event_chain: The event chain to queue on the frontend.
+        event_arg: The browser-native event (only used to preventDefault).
+
+    Returns:
+        Compiled javascript code to queue the given event chain on the frontend.
+
+    Raises:
+        ValueError: When the given event chain is not a valid event chain.
+    """
+    if isinstance(event_chain, Var):
+        from reflex.event import EventChain
+
+        if event_chain.type_ is not EventChain:
+            raise ValueError(f"Invalid event chain: {event_chain}")
+        return "".join(
+            [
+                "(() => {",
+                format_var(event_chain),
+                f"; preventDefault({format_var(event_arg)})" if event_arg else "",
+                "})()",
+            ]
+        )
+
+    chain = ",".join([format_event(event) for event in event_chain.events])
+    return "".join(
+        [
+            f"Event([{chain}]",
+            f", {format_var(event_arg)}" if event_arg else "",
+            ")",
+        ]
+    )
+
+
 def format_query_params(router_data: Dict[str, Any]) -> Dict[str, str]:
     """Convert back query params name to python-friendly case.
 

+ 4 - 2
tests/components/test_component.py

@@ -5,6 +5,7 @@ import pytest
 import reflex as rx
 from reflex.components.component import Component, CustomComponent, custom_component
 from reflex.components.layout.box import Box
+from reflex.constants import ON_MOUNT, ON_UNMOUNT
 from reflex.event import EVENT_ARG, EVENT_TRIGGERS, EventHandler
 from reflex.state import State
 from reflex.style import Style
@@ -377,8 +378,9 @@ def test_get_triggers(component1, component2):
         component1: A test component.
         component2: A test component.
     """
-    assert component1().get_triggers() == EVENT_TRIGGERS
-    assert component2().get_triggers() == {"on_open", "on_close"} | EVENT_TRIGGERS
+    default_triggers = {ON_MOUNT, ON_UNMOUNT} | EVENT_TRIGGERS
+    assert component1().get_triggers() == default_triggers
+    assert component2().get_triggers() == {"on_open", "on_close"} | default_triggers
 
 
 def test_create_custom_component(my_component):