浏览代码

state: _init_event_handlers recursively (#1640)

Masen Furer 1 年之前
父节点
当前提交
12e516da64
共有 2 个文件被更改,包括 32 次插入5 次删除
  1. 24 5
      reflex/state.py
  2. 8 0
      tests/test_state.py

+ 24 - 5
reflex/state.py

@@ -106,11 +106,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         for substate in self.get_substates():
             self.substates[substate.get_name()] = substate(parent_state=self)
         # Convert the event handlers to functions.
-        for name, event_handler in self.event_handlers.items():
-            fn = functools.partial(event_handler.fn, self)
-            fn.__module__ = event_handler.fn.__module__  # type: ignore
-            fn.__qualname__ = event_handler.fn.__qualname__  # type: ignore
-            setattr(self, name, fn)
+        self._init_event_handlers()
 
         # Initialize computed vars dependencies.
         inherited_vars = set(self.inherited_vars).union(
@@ -155,6 +151,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
         self._clean()
 
+    def _init_event_handlers(self, state: State | None = None):
+        """Initialize event handlers.
+
+        Allow event handlers to be called directly on the instance. This is
+        called recursively for all parent states.
+
+        Args:
+            state: The state to initialize the event handlers on.
+        """
+        if state is None:
+            state = self
+
+        # Convert the event handlers to functions.
+        for name, event_handler in state.event_handlers.items():
+            fn = functools.partial(event_handler.fn, self)
+            fn.__module__ = event_handler.fn.__module__  # type: ignore
+            fn.__qualname__ = event_handler.fn.__qualname__  # type: ignore
+            setattr(self, name, fn)
+
+        # Also allow direct calling of parent state event handlers
+        if state.parent_state is not None:
+            self._init_event_handlers(state.parent_state)
+
     def _reassign_field(self, field_name: str):
         """Reassign the given field.
 

+ 8 - 0
tests/test_state.py

@@ -992,10 +992,18 @@ def test_event_handlers_call_other_handlers():
         def set_v2(self, v: int):
             self.set_v(v)
 
+    class SubState(MainState):
+        def set_v3(self, v: int):
+            self.set_v2(v)
+
     ms = MainState()
     ms.set_v2(1)
     assert ms.v == 1
 
+    # ensure handler can be called from substate
+    ms.substates[SubState.get_name()].set_v3(2)
+    assert ms.v == 2
+
 
 def test_computed_var_cached():
     """Test that a ComputedVar doesn't recalculate when accessed."""