Просмотр исходного кода

Proper serialization for chained Event payloads (#1725)

Masen Furer 1 год назад
Родитель
Сommit
cc89f2b6e7

+ 48 - 0
integration/test_event_chain.py

@@ -31,6 +31,9 @@ def EventChain():
         def event_arg(self, arg):
             self.event_order.append(f"event_arg:{arg}")
 
+        def event_arg_repr_type(self, arg):
+            self.event_order.append(f"event_arg_repr:{arg!r}_{type(arg).__name__}")
+
         def event_nested_1(self):
             self.event_order.append("event_nested_1")
             yield State.event_nested_2
@@ -100,6 +103,14 @@ def EventChain():
             self.event_order.append("redirect_yield_chain")
             yield rx.redirect("/on-load-yield-chain")
 
+        def click_return_int_type(self):
+            self.event_order.append("click_return_int_type")
+            return State.event_arg_repr_type(1)  # type: ignore
+
+        def click_return_dict_type(self):
+            self.event_order.append("click_return_dict_type")
+            return State.event_arg_repr_type({"a": 1})  # type: ignore
+
     app = rx.App(state=State)
 
     @app.add_page
@@ -141,6 +152,26 @@ def EventChain():
                 id="redirect_return_chain",
                 on_click=State.redirect_return_chain,
             ),
+            rx.button(
+                "Click Int Type",
+                id="click_int_type",
+                on_click=lambda: State.event_arg_repr_type(1),  # type: ignore
+            ),
+            rx.button(
+                "Click Dict Type",
+                id="click_dict_type",
+                on_click=lambda: State.event_arg_repr_type({"a": 1}),  # type: ignore
+            ),
+            rx.button(
+                "Return Chain Int Type",
+                id="return_int_type",
+                on_click=State.click_return_int_type,
+            ),
+            rx.button(
+                "Return Chain Dict Type",
+                id="return_dict_type",
+                on_click=State.click_return_dict_type,
+            ),
         )
 
     def on_load_return_chain():
@@ -286,6 +317,22 @@ def driver(event_chain: AppHarness):
                 "event_arg:6",
             ],
         ),
+        (
+            "click_int_type",
+            ["event_arg_repr:1_int"],
+        ),
+        (
+            "click_dict_type",
+            ["event_arg_repr:{'a': 1}_dict"],
+        ),
+        (
+            "return_int_type",
+            ["click_return_int_type", "event_arg_repr:1_int"],
+        ),
+        (
+            "return_dict_type",
+            ["click_return_dict_type", "event_arg_repr:{'a': 1}_dict"],
+        ),
     ],
 )
 def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
@@ -356,6 +403,7 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
 
     time.sleep(0.5)
     backend_state = event_chain.app_instance.state_manager.states[token]
+    assert backend_state.is_hydrated is True
     assert backend_state.event_order == exp_event_order
 
 

+ 1 - 1
reflex/event.py

@@ -479,7 +479,7 @@ def fix_events(
             e = e()
         assert isinstance(e, EventSpec), f"Unexpected event type, {type(e)}."
         name = format.format_event_handler(e.handler)
-        payload = {k.name: v.name for k, v in e.args}
+        payload = {k.name: v._decode() for k, v in e.args}
 
         # Create an event and append it to the list.
         out.append(

+ 18 - 0
reflex/vars.py

@@ -144,6 +144,24 @@ class Var(ABC):
         """
         return _GenericAlias(cls, type_)
 
+    def _decode(self) -> Any:
+        """Decode Var as a python value.
+
+        Note that Var with state set cannot be decoded python-side and will be
+        returned as full_name.
+
+        Returns:
+            The decoded value or the Var name.
+        """
+        if self.state:
+            return self.full_name
+        if self.is_string or self.type_ is Figure:
+            return self.name
+        try:
+            return json.loads(self.name)
+        except ValueError:
+            return self.name
+
     def equals(self, other: Var) -> bool:
         """Check if two vars are equal.
 

+ 1 - 1
tests/middleware/test_hydrate_middleware.py

@@ -17,7 +17,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]:
     Returns:
         dict similar to that returned by `State.get_delta` with IS_HYDRATED: True
     """
-    return {state.get_name(): {IS_HYDRATED: "true"}}
+    return {state.get_name(): {IS_HYDRATED: True}}
 
 
 class TestState(State):

+ 1 - 1
tests/test_app.py

@@ -833,7 +833,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                 _dynamic_state_event(name="on_load", val=exp_val, router_data={}),
                 _dynamic_state_event(
                     name="set_is_hydrated",
-                    payload={"value": "true"},
+                    payload={"value": True},
                     val=exp_val,
                     router_data={},
                 ),

+ 30 - 1
tests/test_event.py

@@ -3,7 +3,7 @@ import json
 import pytest
 
 from reflex import event
-from reflex.event import Event, EventHandler, EventSpec
+from reflex.event import Event, EventHandler, EventSpec, fix_events
 from reflex.utils import format
 from reflex.vars import Var
 
@@ -87,6 +87,35 @@ def test_call_event_handler():
         handler(test_fn)  # type: ignore
 
 
+@pytest.mark.parametrize(
+    ("arg1", "arg2"),
+    (
+        (1, 2),
+        (1, "2"),
+        ({"a": 1}, {"b": 2}),
+    ),
+)
+def test_fix_events(arg1, arg2):
+    """Test that chaining an event handler with args formats the payload correctly.
+
+    Args:
+        arg1: The first arg passed to the handler.
+        arg2: The second arg passed to the handler.
+    """
+
+    def test_fn_with_args(_, arg1, arg2):
+        pass
+
+    test_fn_with_args.__qualname__ = "test_fn_with_args"
+
+    handler = EventHandler(fn=test_fn_with_args)
+    event_spec = handler(arg1, arg2)
+    event = fix_events([event_spec], token="foo")[0]
+    assert event.name == test_fn_with_args.__qualname__
+    assert event.token == "foo"
+    assert event.payload == {"arg1": arg1, "arg2": arg2}
+
+
 def test_event_redirect():
     """Test the event redirect function."""
     spec = event.redirect("/path")