|
@@ -29,8 +29,12 @@ from typing_extensions import ParamSpec, Protocol, get_args, get_origin
|
|
|
|
|
|
from reflex import constants
|
|
|
from reflex.utils import console, format
|
|
|
-from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
|
|
-from reflex.utils.types import ArgsSpec, GenericType
|
|
|
+from reflex.utils.exceptions import (
|
|
|
+ EventFnArgMismatch,
|
|
|
+ EventHandlerArgMismatch,
|
|
|
+ EventHandlerArgTypeMismatch,
|
|
|
+)
|
|
|
+from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
|
|
|
from reflex.vars import VarData
|
|
|
from reflex.vars.base import (
|
|
|
LiteralVar,
|
|
@@ -401,7 +405,9 @@ class EventChain(EventActionsMixin):
|
|
|
default_factory=list
|
|
|
)
|
|
|
|
|
|
- args_spec: Optional[Callable] = dataclasses.field(default=None)
|
|
|
+ args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field(
|
|
|
+ default=None
|
|
|
+ )
|
|
|
|
|
|
invocation: Optional[Var] = dataclasses.field(default=None)
|
|
|
|
|
@@ -1053,7 +1059,8 @@ def get_hydrate_event(state) -> str:
|
|
|
|
|
|
def call_event_handler(
|
|
|
event_handler: EventHandler | EventSpec,
|
|
|
- arg_spec: ArgsSpec,
|
|
|
+ arg_spec: ArgsSpec | Sequence[ArgsSpec],
|
|
|
+ key: Optional[str] = None,
|
|
|
) -> EventSpec:
|
|
|
"""Call an event handler to get the event spec.
|
|
|
|
|
@@ -1064,12 +1071,16 @@ def call_event_handler(
|
|
|
Args:
|
|
|
event_handler: The event handler.
|
|
|
arg_spec: The lambda that define the argument(s) to pass to the event handler.
|
|
|
+ key: The key to pass to the event handler.
|
|
|
|
|
|
Raises:
|
|
|
EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
|
|
|
|
|
|
Returns:
|
|
|
The event spec from calling the event handler.
|
|
|
+
|
|
|
+ # noqa: DAR401 failure
|
|
|
+
|
|
|
"""
|
|
|
parsed_args = parse_args_spec(arg_spec) # type: ignore
|
|
|
|
|
@@ -1077,19 +1088,113 @@ def call_event_handler(
|
|
|
# Handle partial application of EventSpec args
|
|
|
return event_handler.add_args(*parsed_args)
|
|
|
|
|
|
- args = inspect.getfullargspec(event_handler.fn).args
|
|
|
- n_args = len(args) - 1 # subtract 1 for bound self arg
|
|
|
- if n_args == len(parsed_args):
|
|
|
- return event_handler(*parsed_args) # type: ignore
|
|
|
- else:
|
|
|
+ provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
|
|
|
+
|
|
|
+ provided_callback_n_args = (
|
|
|
+ len(provided_callback_fullspec.args) - 1
|
|
|
+ ) # subtract 1 for bound self arg
|
|
|
+
|
|
|
+ if provided_callback_n_args != len(parsed_args):
|
|
|
raise EventHandlerArgMismatch(
|
|
|
"The number of arguments accepted by "
|
|
|
- f"{event_handler.fn.__qualname__} ({n_args}) "
|
|
|
+ f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
|
|
|
"does not match the arguments passed by the event trigger: "
|
|
|
f"{[str(v) for v in parsed_args]}\n"
|
|
|
"See https://reflex.dev/docs/events/event-arguments/"
|
|
|
)
|
|
|
|
|
|
+ all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
|
|
|
+
|
|
|
+ event_spec_return_types = list(
|
|
|
+ filter(
|
|
|
+ lambda event_spec_return_type: event_spec_return_type is not None
|
|
|
+ and get_origin(event_spec_return_type) is tuple,
|
|
|
+ (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ if event_spec_return_types:
|
|
|
+ failures = []
|
|
|
+
|
|
|
+ for event_spec_index, event_spec_return_type in enumerate(
|
|
|
+ event_spec_return_types
|
|
|
+ ):
|
|
|
+ args = get_args(event_spec_return_type)
|
|
|
+
|
|
|
+ args_types_without_vars = [
|
|
|
+ arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
|
|
|
+ ]
|
|
|
+
|
|
|
+ try:
|
|
|
+ type_hints_of_provided_callback = get_type_hints(event_handler.fn)
|
|
|
+ except NameError:
|
|
|
+ type_hints_of_provided_callback = {}
|
|
|
+
|
|
|
+ failed_type_check = False
|
|
|
+
|
|
|
+ # check that args of event handler are matching the spec if type hints are provided
|
|
|
+ for i, arg in enumerate(provided_callback_fullspec.args[1:]):
|
|
|
+ if arg not in type_hints_of_provided_callback:
|
|
|
+ continue
|
|
|
+
|
|
|
+ try:
|
|
|
+ compare_result = typehint_issubclass(
|
|
|
+ args_types_without_vars[i], type_hints_of_provided_callback[arg]
|
|
|
+ )
|
|
|
+ except TypeError:
|
|
|
+ # TODO: In 0.7.0, remove this block and raise the exception
|
|
|
+ # raise TypeError(
|
|
|
+ # f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
|
|
|
+ # ) from e
|
|
|
+ console.warn(
|
|
|
+ f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
|
|
|
+ )
|
|
|
+ compare_result = False
|
|
|
+
|
|
|
+ if compare_result:
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ failure = EventHandlerArgTypeMismatch(
|
|
|
+ f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
|
|
|
+ )
|
|
|
+ failures.append(failure)
|
|
|
+ failed_type_check = True
|
|
|
+ break
|
|
|
+
|
|
|
+ if not failed_type_check:
|
|
|
+ if event_spec_index:
|
|
|
+ args = get_args(event_spec_return_types[0])
|
|
|
+
|
|
|
+ args_types_without_vars = [
|
|
|
+ arg if get_origin(arg) is not Var else get_args(arg)[0]
|
|
|
+ for arg in args
|
|
|
+ ]
|
|
|
+
|
|
|
+ expect_string = ", ".join(
|
|
|
+ repr(arg) for arg in args_types_without_vars
|
|
|
+ ).replace("[", "\\[")
|
|
|
+
|
|
|
+ given_string = ", ".join(
|
|
|
+ repr(type_hints_of_provided_callback.get(arg, Any))
|
|
|
+ for arg in provided_callback_fullspec.args[1:]
|
|
|
+ ).replace("[", "\\[")
|
|
|
+
|
|
|
+ console.warn(
|
|
|
+ f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
|
|
|
+ f"This may lead to unexpected behavior but is intentionally ignored for {key}."
|
|
|
+ )
|
|
|
+ return event_handler(*parsed_args)
|
|
|
+
|
|
|
+ if failures:
|
|
|
+ console.deprecate(
|
|
|
+ "Mismatched event handler argument types",
|
|
|
+ "\n".join([str(f) for f in failures]),
|
|
|
+ "0.6.5",
|
|
|
+ "0.7.0",
|
|
|
+ )
|
|
|
+
|
|
|
+ return event_handler(*parsed_args) # type: ignore
|
|
|
+
|
|
|
|
|
|
def unwrap_var_annotation(annotation: GenericType):
|
|
|
"""Unwrap a Var annotation or return it as is if it's not Var[X].
|
|
@@ -1128,7 +1233,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
|
|
|
return annotation
|
|
|
|
|
|
|
|
|
-def parse_args_spec(arg_spec: ArgsSpec):
|
|
|
+def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
|
|
|
"""Parse the args provided in the ArgsSpec of an event trigger.
|
|
|
|
|
|
Args:
|
|
@@ -1137,6 +1242,8 @@ def parse_args_spec(arg_spec: ArgsSpec):
|
|
|
Returns:
|
|
|
The parsed args.
|
|
|
"""
|
|
|
+ # if there's multiple, the first is the default
|
|
|
+ arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec
|
|
|
spec = inspect.getfullargspec(arg_spec)
|
|
|
annotations = get_type_hints(arg_spec)
|
|
|
|
|
@@ -1152,13 +1259,18 @@ def parse_args_spec(arg_spec: ArgsSpec):
|
|
|
)
|
|
|
|
|
|
|
|
|
-def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
|
|
|
+def check_fn_match_arg_spec(
|
|
|
+ fn: Callable,
|
|
|
+ arg_spec: ArgsSpec,
|
|
|
+ key: Optional[str] = None,
|
|
|
+) -> List[Var]:
|
|
|
"""Ensures that the function signature matches the passed argument specification
|
|
|
or raises an EventFnArgMismatch if they do not.
|
|
|
|
|
|
Args:
|
|
|
fn: The function to be validated.
|
|
|
arg_spec: The argument specification for the event trigger.
|
|
|
+ key: The key to pass to the event handler.
|
|
|
|
|
|
Returns:
|
|
|
The parsed arguments from the argument specification.
|
|
@@ -1184,7 +1296,11 @@ def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
|
|
|
return parsed_args
|
|
|
|
|
|
|
|
|
-def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|
|
+def call_event_fn(
|
|
|
+ fn: Callable,
|
|
|
+ arg_spec: ArgsSpec,
|
|
|
+ key: Optional[str] = None,
|
|
|
+) -> list[EventSpec] | Var:
|
|
|
"""Call a function to a list of event specs.
|
|
|
|
|
|
The function should return a single EventSpec, a list of EventSpecs, or a
|
|
@@ -1193,6 +1309,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|
|
Args:
|
|
|
fn: The function to call.
|
|
|
arg_spec: The argument spec for the event trigger.
|
|
|
+ key: The key to pass to the event handler.
|
|
|
|
|
|
Returns:
|
|
|
The event specs from calling the function or a Var.
|
|
@@ -1205,7 +1322,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|
|
from reflex.utils.exceptions import EventHandlerValueError
|
|
|
|
|
|
# Check that fn signature matches arg_spec
|
|
|
- parsed_args = check_fn_match_arg_spec(fn, arg_spec)
|
|
|
+ parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
|
|
|
|
|
|
# Call the function with the parsed args.
|
|
|
out = fn(*parsed_args)
|
|
@@ -1223,7 +1340,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|
|
for e in out:
|
|
|
if isinstance(e, EventHandler):
|
|
|
# An un-called EventHandler gets all of the args of the event trigger.
|
|
|
- e = call_event_handler(e, arg_spec)
|
|
|
+ e = call_event_handler(e, arg_spec, key=key)
|
|
|
|
|
|
# Make sure the event spec is valid.
|
|
|
if not isinstance(e, EventSpec):
|
|
@@ -1433,7 +1550,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
|
|
|
Returns:
|
|
|
The created LiteralEventChainVar instance.
|
|
|
"""
|
|
|
- sig = inspect.signature(value.args_spec) # type: ignore
|
|
|
+ arg_spec = (
|
|
|
+ value.args_spec[0]
|
|
|
+ if isinstance(value.args_spec, Sequence)
|
|
|
+ else value.args_spec
|
|
|
+ )
|
|
|
+ sig = inspect.signature(arg_spec) # type: ignore
|
|
|
if sig.parameters:
|
|
|
arg_def = tuple((f"_{p}" for p in sig.parameters))
|
|
|
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
|