|
@@ -18,10 +18,12 @@ from typing import (
|
|
|
get_type_hints,
|
|
|
)
|
|
|
|
|
|
+from typing_extensions import get_args, get_origin
|
|
|
+
|
|
|
from reflex import constants
|
|
|
from reflex.utils import format
|
|
|
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
|
|
-from reflex.utils.types import ArgsSpec
|
|
|
+from reflex.utils.types import ArgsSpec, GenericType
|
|
|
from reflex.vars import VarData
|
|
|
from reflex.vars.base import LiteralVar, Var
|
|
|
from reflex.vars.function import FunctionStringVar, FunctionVar
|
|
@@ -417,7 +419,7 @@ class FileUpload:
|
|
|
on_upload_progress: Optional[Union[EventHandler, Callable]] = None
|
|
|
|
|
|
@staticmethod
|
|
|
- def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]):
|
|
|
+ def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]):
|
|
|
"""Args spec for on_upload_progress event handler.
|
|
|
|
|
|
Returns:
|
|
@@ -910,6 +912,20 @@ def call_event_handler(
|
|
|
)
|
|
|
|
|
|
|
|
|
+def unwrap_var_annotation(annotation: GenericType):
|
|
|
+ """Unwrap a Var annotation or return it as is if it's not Var[X].
|
|
|
+
|
|
|
+ Args:
|
|
|
+ annotation: The annotation to unwrap.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The unwrapped annotation.
|
|
|
+ """
|
|
|
+ if get_origin(annotation) is Var and (args := get_args(annotation)):
|
|
|
+ return args[0]
|
|
|
+ return annotation
|
|
|
+
|
|
|
+
|
|
|
def parse_args_spec(arg_spec: ArgsSpec):
|
|
|
"""Parse the args provided in the ArgsSpec of an event trigger.
|
|
|
|
|
@@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec):
|
|
|
"""
|
|
|
spec = inspect.getfullargspec(arg_spec)
|
|
|
annotations = get_type_hints(arg_spec)
|
|
|
+
|
|
|
return arg_spec(
|
|
|
*[
|
|
|
- Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent))
|
|
|
+ Var(f"_{l_arg}").to(
|
|
|
+ unwrap_var_annotation(annotations.get(l_arg, FrontendEvent))
|
|
|
+ )
|
|
|
for l_arg in spec.args
|
|
|
]
|
|
|
)
|
|
|
|
|
|
|
|
|
+def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
|
|
|
+ """Ensures that the function signature matches the passed argument specification
|
|
|
+ or raises an EventFnArgMismatch if they do not.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ fn: The function to be validated.
|
|
|
+ arg_spec: The argument specification for the event trigger.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The parsed arguments from the argument specification.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ EventFnArgMismatch: Raised if the number of mandatory arguments do not match
|
|
|
+ """
|
|
|
+ fn_args = inspect.getfullargspec(fn).args
|
|
|
+ fn_defaults_args = inspect.getfullargspec(fn).defaults
|
|
|
+ n_fn_args = len(fn_args)
|
|
|
+ n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
|
|
|
+ if isinstance(fn, types.MethodType):
|
|
|
+ n_fn_args -= 1 # subtract 1 for bound self arg
|
|
|
+ parsed_args = parse_args_spec(arg_spec)
|
|
|
+ if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
|
|
|
+ raise EventFnArgMismatch(
|
|
|
+ "The number of mandatory arguments accepted by "
|
|
|
+ f"{fn} ({n_fn_args - n_fn_defaults_args}) "
|
|
|
+ "does not match the arguments passed by the event trigger: "
|
|
|
+ f"{[str(v) for v in parsed_args]}\n"
|
|
|
+ "See https://reflex.dev/docs/events/event-arguments/"
|
|
|
+ )
|
|
|
+ return parsed_args
|
|
|
+
|
|
|
+
|
|
|
def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|
|
"""Call a function to a list of event specs.
|
|
|
|
|
|
The function should return a single EventSpec, a list of EventSpecs, or a
|
|
|
- single Var. The function signature must match the passed arg_spec or
|
|
|
- EventFnArgsMismatch will be raised.
|
|
|
+ single Var.
|
|
|
|
|
|
Args:
|
|
|
fn: The function to call.
|
|
@@ -944,7 +994,6 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|
|
The event specs from calling the function or a Var.
|
|
|
|
|
|
Raises:
|
|
|
- EventFnArgMismatch: If the function signature doesn't match the arg spec.
|
|
|
EventHandlerValueError: If the lambda returns an unusable value.
|
|
|
"""
|
|
|
# Import here to avoid circular imports.
|
|
@@ -952,19 +1001,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|
|
from reflex.utils.exceptions import EventHandlerValueError
|
|
|
|
|
|
# Check that fn signature matches arg_spec
|
|
|
- fn_args = inspect.getfullargspec(fn).args
|
|
|
- n_fn_args = len(fn_args)
|
|
|
- if isinstance(fn, types.MethodType):
|
|
|
- n_fn_args -= 1 # subtract 1 for bound self arg
|
|
|
- parsed_args = parse_args_spec(arg_spec)
|
|
|
- if len(parsed_args) != n_fn_args:
|
|
|
- raise EventFnArgMismatch(
|
|
|
- "The number of arguments accepted by "
|
|
|
- f"{fn} ({n_fn_args}) "
|
|
|
- "does not match the arguments passed by the event trigger: "
|
|
|
- f"{[str(v) for v in parsed_args]}\n"
|
|
|
- "See https://reflex.dev/docs/events/event-arguments/"
|
|
|
- )
|
|
|
+ parsed_args = check_fn_match_arg_spec(fn, arg_spec)
|
|
|
|
|
|
# Call the function with the parsed args.
|
|
|
out = fn(*parsed_args)
|