Browse Source

EventFnArgMismatch fix to support defaults args (#4004)

* EventFnArgMismatch fix to support defaults args

* fixing type hint and docstring raises

* enforce stronger type checking

* unwrap var annotations :(

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
LeoH 7 months ago
parent
commit
60276cf1ff
3 changed files with 74 additions and 22 deletions
  1. 56 19
      reflex/event.py
  2. 17 2
      reflex/utils/types.py
  3. 1 1
      tests/units/test_event.py

+ 56 - 19
reflex/event.py

@@ -18,10 +18,12 @@ from typing import (
     get_type_hints,
 )
 
+from typing_extensions import get_args, get_origin
+
 from reflex import constants
 from reflex.utils import format
 from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
-from reflex.utils.types import ArgsSpec
+from reflex.utils.types import ArgsSpec, GenericType
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.function import FunctionStringVar, FunctionVar
@@ -417,7 +419,7 @@ class FileUpload:
     on_upload_progress: Optional[Union[EventHandler, Callable]] = None
 
     @staticmethod
-    def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]):
+    def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]):
         """Args spec for on_upload_progress event handler.
 
         Returns:
@@ -910,6 +912,20 @@ def call_event_handler(
         )
 
 
+def unwrap_var_annotation(annotation: GenericType):
+    """Unwrap a Var annotation or return it as is if it's not Var[X].
+
+    Args:
+        annotation: The annotation to unwrap.
+
+    Returns:
+        The unwrapped annotation.
+    """
+    if get_origin(annotation) is Var and (args := get_args(annotation)):
+        return args[0]
+    return annotation
+
+
 def parse_args_spec(arg_spec: ArgsSpec):
     """Parse the args provided in the ArgsSpec of an event trigger.
 
@@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec):
     """
     spec = inspect.getfullargspec(arg_spec)
     annotations = get_type_hints(arg_spec)
+
     return arg_spec(
         *[
-            Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent))
+            Var(f"_{l_arg}").to(
+                unwrap_var_annotation(annotations.get(l_arg, FrontendEvent))
+            )
             for l_arg in spec.args
         ]
     )
 
 
+def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
+    """Ensures that the function signature matches the passed argument specification
+    or raises an EventFnArgMismatch if they do not.
+
+    Args:
+        fn: The function to be validated.
+        arg_spec: The argument specification for the event trigger.
+
+    Returns:
+        The parsed arguments from the argument specification.
+
+    Raises:
+        EventFnArgMismatch: Raised if the number of mandatory arguments do not match
+    """
+    fn_args = inspect.getfullargspec(fn).args
+    fn_defaults_args = inspect.getfullargspec(fn).defaults
+    n_fn_args = len(fn_args)
+    n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
+    if isinstance(fn, types.MethodType):
+        n_fn_args -= 1  # subtract 1 for bound self arg
+    parsed_args = parse_args_spec(arg_spec)
+    if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
+        raise EventFnArgMismatch(
+            "The number of mandatory arguments accepted by "
+            f"{fn} ({n_fn_args - n_fn_defaults_args}) "
+            "does not match the arguments passed by the event trigger: "
+            f"{[str(v) for v in parsed_args]}\n"
+            "See https://reflex.dev/docs/events/event-arguments/"
+        )
+    return parsed_args
+
+
 def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     """Call a function to a list of event specs.
 
     The function should return a single EventSpec, a list of EventSpecs, or a
-    single Var. The function signature must match the passed arg_spec or
-    EventFnArgsMismatch will be raised.
+    single Var.
 
     Args:
         fn: The function to call.
@@ -944,7 +994,6 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
         The event specs from calling the function or a Var.
 
     Raises:
-        EventFnArgMismatch: If the function signature doesn't match the arg spec.
         EventHandlerValueError: If the lambda returns an unusable value.
     """
     # Import here to avoid circular imports.
@@ -952,19 +1001,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
     from reflex.utils.exceptions import EventHandlerValueError
 
     # Check that fn signature matches arg_spec
-    fn_args = inspect.getfullargspec(fn).args
-    n_fn_args = len(fn_args)
-    if isinstance(fn, types.MethodType):
-        n_fn_args -= 1  # subtract 1 for bound self arg
-    parsed_args = parse_args_spec(arg_spec)
-    if len(parsed_args) != n_fn_args:
-        raise EventFnArgMismatch(
-            "The number of arguments accepted by "
-            f"{fn} ({n_fn_args}) "
-            "does not match the arguments passed by the event trigger: "
-            f"{[str(v) for v in parsed_args]}\n"
-            "See https://reflex.dev/docs/events/event-arguments/"
-        )
+    parsed_args = check_fn_match_arg_spec(fn, arg_spec)
 
     # Call the function with the parsed args.
     out = fn(*parsed_args)

+ 17 - 2
reflex/utils/types.py

@@ -9,6 +9,7 @@ import sys
 import types
 from functools import cached_property, lru_cache, wraps
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     ClassVar,
@@ -96,8 +97,22 @@ PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
 StateVar = Union[PrimitiveType, Base, None]
 StateIterVar = Union[list, set, tuple]
 
-# ArgsSpec = Callable[[Var], list[Var]]
-ArgsSpec = Callable
+if TYPE_CHECKING:
+    from reflex.vars.base import Var
+
+    # ArgsSpec = Callable[[Var], list[Var]]
+    ArgsSpec = (
+        Callable[[], List[Var]]
+        | Callable[[Var], List[Var]]
+        | Callable[[Var, Var], List[Var]]
+        | Callable[[Var, Var, Var], List[Var]]
+        | Callable[[Var, Var, Var, Var], List[Var]]
+        | Callable[[Var, Var, Var, Var, Var], List[Var]]
+        | Callable[[Var, Var, Var, Var, Var, Var], List[Var]]
+        | Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]]
+    )
+else:
+    ArgsSpec = Callable[..., List[Any]]
 
 
 PrimitiveToAnnotation = {

+ 1 - 1
tests/units/test_event.py

@@ -97,7 +97,7 @@ def test_call_event_handler_partial():
 
     test_fn_with_args.__qualname__ = "test_fn_with_args"
 
-    def spec(a2: str) -> List[str]:
+    def spec(a2: Var[str]) -> List[Var[str]]:
         return [a2]
 
     handler = EventHandler(fn=test_fn_with_args)