|
@@ -1262,6 +1262,118 @@ def get_hydrate_event(state: BaseState) -> str:
|
|
|
return get_event(state, constants.CompileVars.HYDRATE)
|
|
|
|
|
|
|
|
|
+def _values_returned_from_event(
|
|
|
+ event_spec: ArgsSpec | Sequence[ArgsSpec],
|
|
|
+) -> list[Any]:
|
|
|
+ return [
|
|
|
+ event_spec_return_type
|
|
|
+ for arg_spec in (
|
|
|
+ [event_spec] if not isinstance(event_spec, Sequence) else list(event_spec)
|
|
|
+ )
|
|
|
+ if (event_spec_return_type := get_type_hints(arg_spec).get("return", None))
|
|
|
+ is not None
|
|
|
+ and get_origin(event_spec_return_type) is tuple
|
|
|
+ ]
|
|
|
+
|
|
|
+
|
|
|
+def _check_event_args_subclass_of_callback(
|
|
|
+ callback_params_names: list[str],
|
|
|
+ provided_event_types: list[Any],
|
|
|
+ callback_param_name_to_type: dict[str, Any],
|
|
|
+ callback_name: str = "",
|
|
|
+ key: str = "",
|
|
|
+):
|
|
|
+ """Check if the event handler arguments are subclass of the callback.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ callback_params_names: The names of the callback parameters.
|
|
|
+ provided_event_types: The event types.
|
|
|
+ callback_param_name_to_type: The callback parameter name to type mapping.
|
|
|
+ callback_name: The name of the callback.
|
|
|
+ key: The key.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ TypeError: If the event handler arguments are invalid.
|
|
|
+ EventHandlerArgTypeMismatchError: If the event handler arguments do not match the callback.
|
|
|
+
|
|
|
+ # noqa: DAR401 delayed_exceptions[]
|
|
|
+ # noqa: DAR402 EventHandlerArgTypeMismatchError
|
|
|
+ """
|
|
|
+ type_match_found: dict[str, bool] = {}
|
|
|
+ delayed_exceptions: list[EventHandlerArgTypeMismatchError] = []
|
|
|
+
|
|
|
+ for event_spec_index, event_spec_return_type in enumerate(provided_event_types):
|
|
|
+ args = get_args(event_spec_return_type)
|
|
|
+
|
|
|
+ args_types_without_vars = [
|
|
|
+ arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
|
|
|
+ ]
|
|
|
+
|
|
|
+ # check that args of event handler are matching the spec if type hints are provided
|
|
|
+ for i, arg in enumerate(callback_params_names[: len(args_types_without_vars)]):
|
|
|
+ if arg not in callback_param_name_to_type:
|
|
|
+ continue
|
|
|
+
|
|
|
+ type_match_found.setdefault(arg, False)
|
|
|
+
|
|
|
+ try:
|
|
|
+ compare_result = typehint_issubclass(
|
|
|
+ args_types_without_vars[i], callback_param_name_to_type[arg]
|
|
|
+ )
|
|
|
+ except TypeError as te:
|
|
|
+ callback_name_context = f" of {callback_name}" if callback_name else ""
|
|
|
+ key_context = f" for {key}" if key else ""
|
|
|
+ raise TypeError(
|
|
|
+ f"Could not compare types {args_types_without_vars[i]} and {callback_param_name_to_type[arg]} for argument {arg}{callback_name_context}{key_context}."
|
|
|
+ ) from te
|
|
|
+
|
|
|
+ if compare_result:
|
|
|
+ type_match_found[arg] = True
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ type_match_found[arg] = False
|
|
|
+ as_annotated_in = (
|
|
|
+ f" as annotated in {callback_name}" if callback_name else ""
|
|
|
+ )
|
|
|
+ delayed_exceptions.append(
|
|
|
+ EventHandlerArgTypeMismatchError(
|
|
|
+ f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {callback_param_name_to_type[arg]}{as_annotated_in} instead."
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ if all(type_match_found.values()):
|
|
|
+ delayed_exceptions.clear()
|
|
|
+ if event_spec_index:
|
|
|
+ args = get_args(provided_event_types[0])
|
|
|
+
|
|
|
+ args_types_without_vars = [
|
|
|
+ arg if get_origin(arg) is not Var else get_args(arg)[0]
|
|
|
+ for arg in args
|
|
|
+ ]
|
|
|
+
|
|
|
+ expect_string = ", ".join(
|
|
|
+ repr(arg) for arg in args_types_without_vars
|
|
|
+ ).replace("[", "\\[")
|
|
|
+
|
|
|
+ given_string = ", ".join(
|
|
|
+ repr(callback_param_name_to_type.get(arg, Any))
|
|
|
+ for arg in callback_params_names
|
|
|
+ ).replace("[", "\\[")
|
|
|
+
|
|
|
+ as_annotated_in = (
|
|
|
+ f" as annotated in {callback_name}" if callback_name else ""
|
|
|
+ )
|
|
|
+
|
|
|
+ console.warn(
|
|
|
+ f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> (){as_annotated_in} instead. "
|
|
|
+ f"This may lead to unexpected behavior but is intentionally ignored for {key}."
|
|
|
+ )
|
|
|
+ break
|
|
|
+
|
|
|
+ if delayed_exceptions:
|
|
|
+ raise delayed_exceptions[0]
|
|
|
+
|
|
|
+
|
|
|
def call_event_handler(
|
|
|
event_callback: EventHandler | EventSpec,
|
|
|
event_spec: ArgsSpec | Sequence[ArgsSpec],
|
|
@@ -1278,17 +1390,13 @@ def call_event_handler(
|
|
|
event_spec: The lambda that define the argument(s) to pass to the event handler.
|
|
|
key: The key to pass to the event handler.
|
|
|
|
|
|
- Raises:
|
|
|
- EventHandlerArgTypeMismatchError: If the event handler arguments do not match the event spec. #noqa: DAR402
|
|
|
- TypeError: If the event handler arguments are invalid.
|
|
|
-
|
|
|
Returns:
|
|
|
The event spec from calling the event handler.
|
|
|
-
|
|
|
- #noqa: DAR401
|
|
|
"""
|
|
|
event_spec_args = parse_args_spec(event_spec)
|
|
|
|
|
|
+ event_spec_return_types = _values_returned_from_event(event_spec)
|
|
|
+
|
|
|
if isinstance(event_callback, EventSpec):
|
|
|
check_fn_match_arg_spec(
|
|
|
event_callback.handler.fn,
|
|
@@ -1297,6 +1405,32 @@ def call_event_handler(
|
|
|
bool(event_callback.handler.state_full_name) + len(event_callback.args),
|
|
|
event_callback.handler.fn.__qualname__,
|
|
|
)
|
|
|
+
|
|
|
+ event_callback_spec_args = list(
|
|
|
+ inspect.signature(event_callback.handler.fn).parameters.keys()
|
|
|
+ )
|
|
|
+
|
|
|
+ try:
|
|
|
+ type_hints_of_provided_callback = get_type_hints(event_callback.handler.fn)
|
|
|
+ except NameError:
|
|
|
+ type_hints_of_provided_callback = {}
|
|
|
+
|
|
|
+ argument_names = [str(arg) for arg, value in event_callback.args]
|
|
|
+
|
|
|
+ _check_event_args_subclass_of_callback(
|
|
|
+ [
|
|
|
+ arg
|
|
|
+ for arg in event_callback_spec_args[
|
|
|
+ bool(event_callback.handler.state_full_name) :
|
|
|
+ ]
|
|
|
+ if arg not in argument_names
|
|
|
+ ],
|
|
|
+ event_spec_return_types,
|
|
|
+ type_hints_of_provided_callback,
|
|
|
+ event_callback.handler.fn.__qualname__,
|
|
|
+ key or "",
|
|
|
+ )
|
|
|
+
|
|
|
# Handle partial application of EventSpec args
|
|
|
return event_callback.add_args(*event_spec_args)
|
|
|
|
|
@@ -1308,98 +1442,23 @@ def call_event_handler(
|
|
|
event_callback.fn.__qualname__,
|
|
|
)
|
|
|
|
|
|
- all_acceptable_specs = (
|
|
|
- [event_spec] if not isinstance(event_spec, Sequence) else event_spec
|
|
|
- )
|
|
|
-
|
|
|
- event_spec_return_types = list(
|
|
|
- filter(
|
|
|
- lambda event_spec_return_type: event_spec_return_type is not None
|
|
|
- and get_origin(event_spec_return_type) is tuple,
|
|
|
- (
|
|
|
- get_type_hints(arg_spec).get("return", None)
|
|
|
- for arg_spec in all_acceptable_specs
|
|
|
- ),
|
|
|
- )
|
|
|
- )
|
|
|
- type_match_found: dict[str, bool] = {}
|
|
|
- delayed_exceptions: list[EventHandlerArgTypeMismatchError] = []
|
|
|
-
|
|
|
- try:
|
|
|
- type_hints_of_provided_callback = get_type_hints(event_callback.fn)
|
|
|
- except NameError:
|
|
|
- type_hints_of_provided_callback = {}
|
|
|
-
|
|
|
if event_spec_return_types:
|
|
|
event_callback_spec_args = list(
|
|
|
inspect.signature(event_callback.fn).parameters.keys()
|
|
|
)
|
|
|
|
|
|
- for event_spec_index, event_spec_return_type in enumerate(
|
|
|
- event_spec_return_types
|
|
|
- ):
|
|
|
- args = get_args(event_spec_return_type)
|
|
|
-
|
|
|
- args_types_without_vars = [
|
|
|
- arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
|
|
|
- ]
|
|
|
-
|
|
|
- # check that args of event handler are matching the spec if type hints are provided
|
|
|
- for i, arg in enumerate(
|
|
|
- event_callback_spec_args[1 : len(args_types_without_vars) + 1]
|
|
|
- ):
|
|
|
- if arg not in type_hints_of_provided_callback:
|
|
|
- continue
|
|
|
-
|
|
|
- type_match_found.setdefault(arg, False)
|
|
|
-
|
|
|
- try:
|
|
|
- compare_result = typehint_issubclass(
|
|
|
- args_types_without_vars[i], type_hints_of_provided_callback[arg]
|
|
|
- )
|
|
|
- except TypeError as te:
|
|
|
- raise TypeError(
|
|
|
- f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
|
|
|
- ) from te
|
|
|
-
|
|
|
- if compare_result:
|
|
|
- type_match_found[arg] = True
|
|
|
- continue
|
|
|
- else:
|
|
|
- type_match_found[arg] = False
|
|
|
- delayed_exceptions.append(
|
|
|
- EventHandlerArgTypeMismatchError(
|
|
|
- f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- if all(type_match_found.values()):
|
|
|
- delayed_exceptions.clear()
|
|
|
- if event_spec_index:
|
|
|
- args = get_args(event_spec_return_types[0])
|
|
|
-
|
|
|
- args_types_without_vars = [
|
|
|
- arg if get_origin(arg) is not Var else get_args(arg)[0]
|
|
|
- for arg in args
|
|
|
- ]
|
|
|
-
|
|
|
- expect_string = ", ".join(
|
|
|
- repr(arg) for arg in args_types_without_vars
|
|
|
- ).replace("[", "\\[")
|
|
|
-
|
|
|
- given_string = ", ".join(
|
|
|
- repr(type_hints_of_provided_callback.get(arg, Any))
|
|
|
- for arg in event_callback_spec_args[1:]
|
|
|
- ).replace("[", "\\[")
|
|
|
-
|
|
|
- console.warn(
|
|
|
- f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
|
|
|
- f"This may lead to unexpected behavior but is intentionally ignored for {key}."
|
|
|
- )
|
|
|
- break
|
|
|
-
|
|
|
- if delayed_exceptions:
|
|
|
- raise delayed_exceptions[0]
|
|
|
+ try:
|
|
|
+ type_hints_of_provided_callback = get_type_hints(event_callback.fn)
|
|
|
+ except NameError:
|
|
|
+ type_hints_of_provided_callback = {}
|
|
|
+
|
|
|
+ _check_event_args_subclass_of_callback(
|
|
|
+ event_callback_spec_args[1:],
|
|
|
+ event_spec_return_types,
|
|
|
+ type_hints_of_provided_callback,
|
|
|
+ event_callback.fn.__qualname__,
|
|
|
+ key or "",
|
|
|
+ )
|
|
|
|
|
|
return event_callback(*event_spec_args)
|
|
|
|