|
@@ -109,6 +109,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
|
|
|
Returns:
|
|
|
The partially applied function.
|
|
|
"""
|
|
|
+ self._pre_check(*args)
|
|
|
if not args:
|
|
|
return ArgsFunctionOperation.create((), self)
|
|
|
return ArgsFunctionOperation.create(
|
|
@@ -183,8 +184,20 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
|
|
|
Returns:
|
|
|
The function call operation.
|
|
|
"""
|
|
|
+ self._pre_check(*args)
|
|
|
return VarOperationCall.create(self, *args).guess_type()
|
|
|
|
|
|
+ def _pre_check(self, *args: Var | Any) -> bool:
|
|
|
+ """Check if the function can be called with the given arguments.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ *args: The arguments to call the function with.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ True if the function can be called with the given arguments.
|
|
|
+ """
|
|
|
+ return True
|
|
|
+
|
|
|
__call__ = call
|
|
|
|
|
|
|
|
@@ -354,6 +367,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
"""Base class for immutable function defined via arguments and return expression."""
|
|
|
|
|
|
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
|
|
|
+ _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
|
|
|
+ default_factory=tuple
|
|
|
+ )
|
|
|
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
|
|
|
_explicit_return: bool = dataclasses.field(default=False)
|
|
|
|
|
@@ -368,12 +384,27 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
self._args, self._return_expr, self._explicit_return
|
|
|
)
|
|
|
|
|
|
+ def _pre_check(self, *args: Var | Any) -> bool:
|
|
|
+ """Check if the function can be called with the given arguments.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ *args: The arguments to call the function with.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ True if the function can be called with the given arguments.
|
|
|
+ """
|
|
|
+ return all(
|
|
|
+ validator(arg)
|
|
|
+ for validator, arg in zip(self._validators, args, strict=False)
|
|
|
+ )
|
|
|
+
|
|
|
@classmethod
|
|
|
def create(
|
|
|
cls,
|
|
|
args_names: Sequence[Union[str, DestructuredArg]],
|
|
|
return_expr: Var | Any,
|
|
|
rest: str | None = None,
|
|
|
+ validators: Sequence[Callable[[Any], bool]] = (),
|
|
|
explicit_return: bool = False,
|
|
|
_var_type: GenericType = Callable,
|
|
|
_var_data: VarData | None = None,
|
|
@@ -395,6 +426,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
_var_type=_var_type,
|
|
|
_var_data=_var_data,
|
|
|
_args=FunctionArgs(args=tuple(args_names), rest=rest),
|
|
|
+ _validators=tuple(validators),
|
|
|
_return_expr=return_expr,
|
|
|
_explicit_return=explicit_return,
|
|
|
)
|