浏览代码

LiteralEventChainVar becomes an ArgsFunctionOperation (#4174)

* LiteralEventChainVar becomes an ArgsFunctionOperation

Instead of using the ArgsFunctionOperation to create the string representation
of the _js_expr, make the identity of the var an ArgsFunctionOperation so the
_args_names and _return_expr remain accessible.

Rely on the default behavior of ArgsFunctionOperation to create the
_cached_var_name / _js_expr value.

This allows the compat shim in `format_event_chain` to remain functional, as it
does special handling for ArgsFunctionOperation to retain the previous behavior
of that function (this was a regression introduced in 0.6.2).

* _var_type is EventChain; fix parent class order

* Re-fix LiteralEventChainVar inheritence list w/ comment

* [ENG-3942] LiteralEventVar becomes VarCallOperation

instead of using `.call` when constructing the `_js_expr`, have the identity of
a LiteralEventVar as a VarCallOperation to take advantage of the _var_data
carrying.

* add event overlords

* EventCallback descriptor always returns EventSpec from class

Relax actual `__get__` definition to support the multitude of overloads

* test case for event related vars carrying _var_data

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Masen Furer 7 月之前
父节点
当前提交
4422515f4b
共有 2 个文件被更改,包括 138 次插入73 次删除
  1. 103 70
      reflex/event.py
  2. 35 3
      tests/units/test_event.py

+ 103 - 70
reflex/event.py

@@ -22,6 +22,7 @@ from typing import (
     TypeVar,
     Union,
     get_type_hints,
+    overload,
 )
 
 from typing_extensions import ParamSpec, get_args, get_origin
@@ -32,14 +33,17 @@ from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
 from reflex.utils.types import ArgsSpec, GenericType
 from reflex.vars import VarData
 from reflex.vars.base import (
-    CachedVarOperation,
     LiteralNoneVar,
     LiteralVar,
     ToOperation,
     Var,
-    cached_property_no_lock,
 )
-from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar
+from reflex.vars.function import (
+    ArgsFunctionOperation,
+    FunctionStringVar,
+    FunctionVar,
+    VarOperationCall,
+)
 from reflex.vars.object import ObjectVar
 
 try:
@@ -1258,7 +1262,7 @@ class EventVar(ObjectVar):
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
-class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
+class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
     """A literal event var."""
 
     _var_value: EventSpec = dataclasses.field(default=None)  # type: ignore
@@ -1271,35 +1275,6 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
         """
         return hash((self.__class__.__name__, self._js_expr))
 
-    @cached_property_no_lock
-    def _cached_var_name(self) -> str:
-        """The name of the var.
-
-        Returns:
-            The name of the var.
-        """
-        return str(
-            FunctionStringVar("Event").call(
-                # event handler name
-                ".".join(
-                    filter(
-                        None,
-                        format.get_event_handler_parts(self._var_value.handler),
-                    )
-                ),
-                # event handler args
-                {str(name): value for name, value in self._var_value.args},
-                # event actions
-                self._var_value.event_actions,
-                # client handler name
-                *(
-                    [self._var_value.client_handler_name]
-                    if self._var_value.client_handler_name
-                    else []
-                ),
-            )
-        )
-
     @classmethod
     def create(
         cls,
@@ -1320,6 +1295,22 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
             _var_type=EventSpec,
             _var_data=_var_data,
             _var_value=value,
+            _func=FunctionStringVar("Event"),
+            _args=(
+                # event handler name
+                ".".join(
+                    filter(
+                        None,
+                        format.get_event_handler_parts(value.handler),
+                    )
+                ),
+                # event handler args
+                {str(name): value for name, value in value.args},
+                # event actions
+                value.event_actions,
+                # client handler name
+                *([value.client_handler_name] if value.client_handler_name else []),
+            ),
         )
 
 
@@ -1332,7 +1323,10 @@ class EventChainVar(FunctionVar):
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
-class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
+# Note: LiteralVar is second in the inheritance list allowing it act like a
+# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
+# _cached_var_name property.
+class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
     """A literal event chain var."""
 
     _var_value: EventChain = dataclasses.field(default=None)  # type: ignore
@@ -1345,41 +1339,6 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
         """
         return hash((self.__class__.__name__, self._js_expr))
 
-    @cached_property_no_lock
-    def _cached_var_name(self) -> str:
-        """The name of the var.
-
-        Returns:
-            The name of the var.
-        """
-        sig = inspect.signature(self._var_value.args_spec)  # type: ignore
-        if sig.parameters:
-            arg_def = tuple((f"_{p}" for p in sig.parameters))
-            arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
-        else:
-            # add a default argument for addEvents if none were specified in value.args_spec
-            # used to trigger the preventDefault() on the event.
-            arg_def = ("...args",)
-            arg_def_expr = Var(_js_expr="args")
-
-        if self._var_value.invocation is None:
-            invocation = FunctionStringVar.create("addEvents")
-        else:
-            invocation = self._var_value.invocation
-
-        return str(
-            ArgsFunctionOperation.create(
-                arg_def,
-                invocation.call(
-                    LiteralVar.create(
-                        [LiteralVar.create(event) for event in self._var_value.events]
-                    ),
-                    arg_def_expr,
-                    self._var_value.event_actions,
-                ),
-            )
-        )
-
     @classmethod
     def create(
         cls,
@@ -1395,10 +1354,31 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
         Returns:
             The created LiteralEventChainVar instance.
         """
+        sig = inspect.signature(value.args_spec)  # type: ignore
+        if sig.parameters:
+            arg_def = tuple((f"_{p}" for p in sig.parameters))
+            arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
+        else:
+            # add a default argument for addEvents if none were specified in value.args_spec
+            # used to trigger the preventDefault() on the event.
+            arg_def = ("...args",)
+            arg_def_expr = Var(_js_expr="args")
+
+        if value.invocation is None:
+            invocation = FunctionStringVar.create("addEvents")
+        else:
+            invocation = value.invocation
+
         return cls(
             _js_expr="",
             _var_type=EventChain,
             _var_data=_var_data,
+            _args_names=arg_def,
+            _return_expr=invocation.call(
+                LiteralVar.create([LiteralVar.create(event) for event in value.events]),
+                arg_def_expr,
+                value.event_actions,
+            ),
             _var_value=value,
         )
 
@@ -1437,6 +1417,11 @@ EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]]
 
 P = ParamSpec("P")
 T = TypeVar("T")
+V = TypeVar("V")
+V2 = TypeVar("V2")
+V3 = TypeVar("V3")
+V4 = TypeVar("V4")
+V5 = TypeVar("V5")
 
 if sys.version_info >= (3, 10):
     from typing import Concatenate
@@ -1452,7 +1437,55 @@ if sys.version_info >= (3, 10):
             """
             self.func = func
 
-        def __get__(self, instance, owner) -> Callable[P, T]:
+        @overload
+        def __get__(
+            self: EventCallback[[V], T], instance: None, owner
+        ) -> Callable[[Union[Var[V], V]], EventSpec]: ...
+
+        @overload
+        def __get__(
+            self: EventCallback[[V, V2], T], instance: None, owner
+        ) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ...
+
+        @overload
+        def __get__(
+            self: EventCallback[[V, V2, V3], T], instance: None, owner
+        ) -> Callable[
+            [Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]],
+            EventSpec,
+        ]: ...
+
+        @overload
+        def __get__(
+            self: EventCallback[[V, V2, V3, V4], T], instance: None, owner
+        ) -> Callable[
+            [
+                Union[Var[V], V],
+                Union[Var[V2], V2],
+                Union[Var[V3], V3],
+                Union[Var[V4], V4],
+            ],
+            EventSpec,
+        ]: ...
+
+        @overload
+        def __get__(
+            self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner
+        ) -> Callable[
+            [
+                Union[Var[V], V],
+                Union[Var[V2], V2],
+                Union[Var[V3], V3],
+                Union[Var[V4], V4],
+                Union[Var[V5], V5],
+            ],
+            EventSpec,
+        ]: ...
+
+        @overload
+        def __get__(self, instance, owner) -> Callable[P, T]: ...
+
+        def __get__(self, instance, owner) -> Callable:
             """Get the function with the instance bound to it.
 
             Args:

+ 35 - 3
tests/units/test_event.py

@@ -2,11 +2,18 @@ from typing import List
 
 import pytest
 
-from reflex import event
-from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events
+from reflex.event import (
+    Event,
+    EventChain,
+    EventHandler,
+    EventSpec,
+    call_event_handler,
+    event,
+    fix_events,
+)
 from reflex.state import BaseState
 from reflex.utils import format
-from reflex.vars.base import LiteralVar, Var
+from reflex.vars.base import Field, LiteralVar, Var, field
 
 
 def make_var(value) -> Var:
@@ -388,3 +395,28 @@ def test_event_actions_on_state():
     assert sp_handler.event_actions == {"stopPropagation": True}
     # should NOT affect other references to the handler
     assert not handler.event_actions
+
+
+def test_event_var_data():
+    class S(BaseState):
+        x: Field[int] = field(0)
+
+        @event
+        def s(self, value: int):
+            pass
+
+    # Handler doesn't have any _var_data because it's just a str
+    handler_var = Var.create(S.s)
+    assert handler_var._get_all_var_data() is None
+
+    # Ensure spec carries _var_data
+    spec_var = Var.create(S.s(S.x))
+    assert spec_var._get_all_var_data() == S.x._get_all_var_data()
+
+    # Needed to instantiate the EventChain
+    def _args_spec(value: Var[int]) -> tuple[Var[int]]:
+        return (value,)
+
+    # Ensure chain carries _var_data
+    chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec))
+    assert chain_var._get_all_var_data() == S.x._get_all_var_data()