瀏覽代碼

improve event handler state references (#2818)

benedikt-bartscher 1 年之前
父節點
當前提交
8a3c9383fb
共有 3 個文件被更改,包括 33 次插入15 次删除
  1. 4 0
      reflex/event.py
  2. 20 5
      reflex/state.py
  3. 9 10
      reflex/utils/format.py

+ 4 - 0
reflex/event.py

@@ -147,6 +147,10 @@ class EventHandler(EventActionsMixin):
     # The function to call in response to the event.
     fn: Any
 
+    # The full name of the state class this event handler is attached to.
+    # Emtpy string means this event handler is a server side event.
+    state_full_name: str = ""
+
     class Config:
         """The Pydantic config."""
 

+ 20 - 5
reflex/state.py

@@ -472,7 +472,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 events[name] = value
 
         for name, fn in events.items():
-            handler = EventHandler(fn=fn)
+            handler = cls._create_event_handler(fn)
             cls.event_handlers[name] = handler
             setattr(cls, name, handler)
 
@@ -677,7 +677,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
     @classmethod
     @functools.lru_cache()
-    def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
+    def get_class_substate(cls, path: Sequence[str] | str) -> Type[BaseState]:
         """Get the class substate.
 
         Args:
@@ -689,6 +689,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Raises:
             ValueError: If the substate is not found.
         """
+        if isinstance(path, str):
+            path = tuple(path.split("."))
+
         if len(path) == 0:
             return cls
         if path[0] == cls.get_name():
@@ -789,6 +792,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         setattr(cls, prop._var_name, prop)
 
+    @classmethod
+    def _create_event_handler(cls, fn):
+        """Create an event handler for the given function.
+
+        Args:
+            fn: The function to create an event handler for.
+
+        Returns:
+            The event handler.
+        """
+        return EventHandler(fn=fn, state_full_name=cls.get_full_name())
+
     @classmethod
     def _create_setter(cls, prop: BaseVar):
         """Create a setter for the var.
@@ -798,7 +813,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         setter_name = prop.get_setter_name(include_state=False)
         if setter_name not in cls.__dict__:
-            event_handler = EventHandler(fn=prop.get_setter())
+            event_handler = cls._create_event_handler(prop.get_setter())
             cls.event_handlers[setter_name] = event_handler
             setattr(cls, setter_name, event_handler)
 
@@ -1752,7 +1767,7 @@ class UpdateVarsInternalState(State):
         """
         for var, value in vars.items():
             state_name, _, var_name = var.rpartition(".")
-            var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
+            var_state_cls = State.get_class_substate(state_name)
             var_state = await self.get_state(var_state_cls)
             setattr(var_state, var_name, value)
 
@@ -2268,7 +2283,7 @@ class StateManagerRedis(StateManager):
         _, state_path = _split_substate_key(token)
         if state_path:
             # Get the State class associated with the given path.
-            state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
+            state_cls = self.state.get_class_substate(state_path)
         else:
             raise RuntimeError(
                 "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"

+ 9 - 10
reflex/utils/format.py

@@ -6,7 +6,6 @@ import inspect
 import json
 import os
 import re
-import sys
 from typing import TYPE_CHECKING, Any, List, Union
 
 from reflex import constants
@@ -470,18 +469,18 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
     if len(parts) == 1:
         return ("", parts[-1])
 
-    # Get the state and the function name.
-    state_name, name = parts[-2:]
+    # Get the state full name
+    state_full_name = handler.state_full_name
 
-    # Construct the full event handler name.
-    try:
-        # Try to get the state from the module.
-        state = vars(sys.modules[handler.fn.__module__])[state_name]
-    except Exception:
-        # If the state isn't in the module, just return the function name.
+    # Get the function name
+    name = parts[-1]
+
+    from reflex.state import State
+
+    if state_full_name == "state" and name not in State.__dict__:
         return ("", to_snake_case(handler.fn.__qualname__))
 
-    return (state.get_full_name(), name)
+    return (state_full_name, name)
 
 
 def format_event_handler(handler: EventHandler) -> str: