ソースを参照

EventFnArgMismatch fix to support defaults args (#4004)

* EventFnArgMismatch fix to support defaults args

* fixing type hint and docstring raises

* enforce stronger type checking

* unwrap var annotations :(

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
LeoH 7 ヶ月 前
コミット
60276cf1ff
3 ファイル変更74 行追加22 行削除
  1. 56 19
      reflex/event.py
  2. 17 2
      reflex/utils/types.py
  3. 1 1
      tests/units/test_event.py

+ 56 - 19
reflex/event.py

@@ -18,10 +18,12 @@ from typing import (
     get_type_hints,
     get_type_hints,
 )
 )
 
 
+from typing_extensions import get_args, get_origin
+
 from reflex import constants
 from reflex import constants
 from reflex.utils import format
 from reflex.utils import format
 from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
 from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
-from reflex.utils.types import ArgsSpec
+from reflex.utils.types import ArgsSpec, GenericType
 from reflex.vars import VarData
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.function import FunctionStringVar, FunctionVar
 from reflex.vars.function import FunctionStringVar, FunctionVar
@@ -417,7 +419,7 @@ class FileUpload:
     on_upload_progress: Optional[Union[EventHandler, Callable]] = None
     on_upload_progress: Optional[Union[EventHandler, Callable]] = None
 
 
     @staticmethod
     @staticmethod
-    def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]):
+    def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]):
         """Args spec for on_upload_progress event handler.
         """Args spec for on_upload_progress event handler.
 
 
         Returns:
         Returns:
@@ -910,6 +912,20 @@ def call_event_handler(
         )
         )
 
 
 
 
+def unwrap_var_annotation(annotation: GenericType):
+    """Unwrap a Var annotation or return it as is if it's not Var[X].
+
+    Args:
+        annotation: The annotation to unwrap.
+
+    Returns:
+        The unwrapped annotation.
+    """
+    if get_origin(annotation) is Var and (args := get_args(annotation)):
+        return args[0]
+    return annotation
+
+
 def parse_args_spec(arg_spec: ArgsSpec):
 def parse_args_spec(arg_spec: ArgsSpec):
     """Parse the args provided in the ArgsSpec of an event trigger.
     """Parse the args provided in the ArgsSpec of an event trigger.
 
 
@@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec):
     """
     """
     spec = inspect.getfullargspec(arg_spec)
     spec = inspect.getfullargspec(arg_spec)
     annotations = get_type_hints(arg_spec)
     annotations = get_type_hints(arg_spec)
+
     return arg_spec(
     return arg_spec(
         *[
         *[
-            Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent))
+            Var(f"_{l_arg}").to(
+                unwrap_var_annotation(annotations.get(l_arg, FrontendEvent))
+            )
             for l_arg in spec.args
             for l_arg in spec.args
         ]
         ]
     )
     )
 
 
 
 
+def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
+    """Ensures that the function signature matches the passed argument specification
+    or raises an EventFnArgMismatch if they do not.
+
+    Args:
+        fn: The function to be validated.
+        arg_spec: The argument specification for the event trigger.
+
+    Returns:
+        The parsed arguments from the argument specification.
+
+    Raises:
+        EventFnArgMismatch: Raised if the number of mandatory arguments do not match
+    """
+    fn_args = inspect.getfullargspec(fn).args
+    fn_defaults_args = inspect.getfullargspec(fn).defaults
+    n_fn_args = len(fn_args)
+    n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
+    if isinstance(fn, types.MethodType):
+        n_fn_args -= 1  # subtract 1 for bound self arg
+    parsed_args = parse_args_spec(arg_spec)
+    if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
+        raise EventFnArgMismatch(
+            "The number of mandatory arguments accepted by "
+            f"{fn} ({n_fn_args - n_fn_defaults_args}) "
+            "does not match the arguments passed by the event trigger: "
+            f"{[str(v) for v in parsed_args]}\n"
+            "See https://reflex.dev/docs/events/event-arguments/"
+        )
+    return parsed_args
+
+
 def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
 def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     """Call a function to a list of event specs.
     """Call a function to a list of event specs.
 
 
     The function should return a single EventSpec, a list of EventSpecs, or a
     The function should return a single EventSpec, a list of EventSpecs, or a
-    single Var. The function signature must match the passed arg_spec or
-    EventFnArgsMismatch will be raised.
+    single Var.
 
 
     Args:
     Args:
         fn: The function to call.
         fn: The function to call.
@@ -944,7 +994,6 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
         The event specs from calling the function or a Var.
         The event specs from calling the function or a Var.
 
 
     Raises:
     Raises:
-        EventFnArgMismatch: If the function signature doesn't match the arg spec.
         EventHandlerValueError: If the lambda returns an unusable value.
         EventHandlerValueError: If the lambda returns an unusable value.
     """
     """
     # Import here to avoid circular imports.
     # Import here to avoid circular imports.
@@ -952,19 +1001,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     from reflex.utils.exceptions import EventHandlerValueError
     from reflex.utils.exceptions import EventHandlerValueError
 
 
     # Check that fn signature matches arg_spec
     # Check that fn signature matches arg_spec
-    fn_args = inspect.getfullargspec(fn).args
-    n_fn_args = len(fn_args)
-    if isinstance(fn, types.MethodType):
-        n_fn_args -= 1  # subtract 1 for bound self arg
-    parsed_args = parse_args_spec(arg_spec)
-    if len(parsed_args) != n_fn_args:
-        raise EventFnArgMismatch(
-            "The number of arguments accepted by "
-            f"{fn} ({n_fn_args}) "
-            "does not match the arguments passed by the event trigger: "
-            f"{[str(v) for v in parsed_args]}\n"
-            "See https://reflex.dev/docs/events/event-arguments/"
-        )
+    parsed_args = check_fn_match_arg_spec(fn, arg_spec)
 
 
     # Call the function with the parsed args.
     # Call the function with the parsed args.
     out = fn(*parsed_args)
     out = fn(*parsed_args)

+ 17 - 2
reflex/utils/types.py

@@ -9,6 +9,7 @@ import sys
 import types
 import types
 from functools import cached_property, lru_cache, wraps
 from functools import cached_property, lru_cache, wraps
 from typing import (
 from typing import (
+    TYPE_CHECKING,
     Any,
     Any,
     Callable,
     Callable,
     ClassVar,
     ClassVar,
@@ -96,8 +97,22 @@ PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
 StateVar = Union[PrimitiveType, Base, None]
 StateVar = Union[PrimitiveType, Base, None]
 StateIterVar = Union[list, set, tuple]
 StateIterVar = Union[list, set, tuple]
 
 
-# ArgsSpec = Callable[[Var], list[Var]]
-ArgsSpec = Callable
+if TYPE_CHECKING:
+    from reflex.vars.base import Var
+
+    # ArgsSpec = Callable[[Var], list[Var]]
+    ArgsSpec = (
+        Callable[[], List[Var]]
+        | Callable[[Var], List[Var]]
+        | Callable[[Var, Var], List[Var]]
+        | Callable[[Var, Var, Var], List[Var]]
+        | Callable[[Var, Var, Var, Var], List[Var]]
+        | Callable[[Var, Var, Var, Var, Var], List[Var]]
+        | Callable[[Var, Var, Var, Var, Var, Var], List[Var]]
+        | Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]]
+    )
+else:
+    ArgsSpec = Callable[..., List[Any]]
 
 
 
 
 PrimitiveToAnnotation = {
 PrimitiveToAnnotation = {

+ 1 - 1
tests/units/test_event.py

@@ -97,7 +97,7 @@ def test_call_event_handler_partial():
 
 
     test_fn_with_args.__qualname__ = "test_fn_with_args"
     test_fn_with_args.__qualname__ = "test_fn_with_args"
 
 
-    def spec(a2: str) -> List[str]:
+    def spec(a2: Var[str]) -> List[Var[str]]:
         return [a2]
         return [a2]
 
 
     handler = EventHandler(fn=test_fn_with_args)
     handler = EventHandler(fn=test_fn_with_args)