|
@@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overlo
|
|
|
from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
|
|
|
|
|
|
from reflex.utils import format
|
|
|
+from reflex.utils.exceptions import VarTypeError
|
|
|
from reflex.utils.types import GenericType
|
|
|
|
|
|
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
|
|
@@ -109,12 +110,22 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
|
|
|
Returns:
|
|
|
The partially applied function.
|
|
|
"""
|
|
|
- self._pre_check(*args)
|
|
|
if not args:
|
|
|
- return ArgsFunctionOperation.create((), self)
|
|
|
+ return self
|
|
|
+ remaining_validators = self._pre_check(*args)
|
|
|
+ if self.__call__ is self.partial:
|
|
|
+ # if the default behavior is partial, we should return a new partial function
|
|
|
+ return ArgsFunctionOperationBuilder.create(
|
|
|
+ (),
|
|
|
+ VarOperationCall.create(self, *args, Var(_js_expr="...args")),
|
|
|
+ rest="args",
|
|
|
+ validators=remaining_validators,
|
|
|
+ )
|
|
|
return ArgsFunctionOperation.create(
|
|
|
- ("...args",),
|
|
|
+ (),
|
|
|
VarOperationCall.create(self, *args, Var(_js_expr="...args")),
|
|
|
+ rest="args",
|
|
|
+ validators=remaining_validators,
|
|
|
)
|
|
|
|
|
|
@overload
|
|
@@ -187,7 +198,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
|
|
|
self._pre_check(*args)
|
|
|
return VarOperationCall.create(self, *args).guess_type()
|
|
|
|
|
|
- def _pre_check(self, *args: Var | Any) -> bool:
|
|
|
+ def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
|
|
|
"""Check if the function can be called with the given arguments.
|
|
|
|
|
|
Args:
|
|
@@ -196,7 +207,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
|
|
|
Returns:
|
|
|
True if the function can be called with the given arguments.
|
|
|
"""
|
|
|
- return True
|
|
|
+ return tuple()
|
|
|
|
|
|
__call__ = call
|
|
|
|
|
@@ -346,7 +357,8 @@ def format_args_function_operation(
|
|
|
"""
|
|
|
arg_names_str = ", ".join(
|
|
|
[arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
|
|
|
- ) + (f", ...{args.rest}" if args.rest else "")
|
|
|
+ + ([f"...{args.rest}"] if args.rest else [])
|
|
|
+ )
|
|
|
|
|
|
return_expr_str = str(LiteralVar.create(return_expr))
|
|
|
|
|
@@ -371,6 +383,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
default_factory=tuple
|
|
|
)
|
|
|
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
|
|
|
+ _function_name: str = dataclasses.field(default="")
|
|
|
_explicit_return: bool = dataclasses.field(default=False)
|
|
|
|
|
|
@cached_property_no_lock
|
|
@@ -384,7 +397,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
self._args, self._return_expr, self._explicit_return
|
|
|
)
|
|
|
|
|
|
- def _pre_check(self, *args: Var | Any) -> bool:
|
|
|
+ def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
|
|
|
"""Check if the function can be called with the given arguments.
|
|
|
|
|
|
Args:
|
|
@@ -393,10 +406,17 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
Returns:
|
|
|
True if the function can be called with the given arguments.
|
|
|
"""
|
|
|
- return all(
|
|
|
- validator(arg)
|
|
|
- for validator, arg in zip(self._validators, args, strict=False)
|
|
|
- )
|
|
|
+ for i, (validator, arg) in enumerate(zip(self._validators, args)):
|
|
|
+ if not validator(arg):
|
|
|
+ arg_name = self._args.args[i] if i < len(self._args.args) else None
|
|
|
+ if arg_name is not None:
|
|
|
+ raise VarTypeError(
|
|
|
+ f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
|
|
|
+ )
|
|
|
+ raise VarTypeError(
|
|
|
+ f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
|
|
|
+ )
|
|
|
+ return self._validators[len(args) :]
|
|
|
|
|
|
@classmethod
|
|
|
def create(
|
|
@@ -405,6 +425,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
return_expr: Var | Any,
|
|
|
rest: str | None = None,
|
|
|
validators: Sequence[Callable[[Any], bool]] = (),
|
|
|
+ function_name: str = "",
|
|
|
explicit_return: bool = False,
|
|
|
_var_type: GenericType = Callable,
|
|
|
_var_data: VarData | None = None,
|
|
@@ -415,6 +436,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
args_names: The names of the arguments.
|
|
|
return_expr: The return expression of the function.
|
|
|
rest: The name of the rest argument.
|
|
|
+ validators: The validators for the arguments.
|
|
|
+ function_name: The name of the function.
|
|
|
explicit_return: Whether to use explicit return syntax.
|
|
|
_var_data: Additional hooks and imports associated with the Var.
|
|
|
|
|
@@ -426,6 +449,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
|
|
_var_type=_var_type,
|
|
|
_var_data=_var_data,
|
|
|
_args=FunctionArgs(args=tuple(args_names), rest=rest),
|
|
|
+ _function_name=function_name,
|
|
|
_validators=tuple(validators),
|
|
|
_return_expr=return_expr,
|
|
|
_explicit_return=explicit_return,
|
|
@@ -441,7 +465,11 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
|
|
|
"""Base class for immutable function defined via arguments and return expression with the builder pattern."""
|
|
|
|
|
|
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
|
|
|
+ _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
|
|
|
+ default_factory=tuple
|
|
|
+ )
|
|
|
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
|
|
|
+ _function_name: str = dataclasses.field(default="")
|
|
|
_explicit_return: bool = dataclasses.field(default=False)
|
|
|
|
|
|
@cached_property_no_lock
|
|
@@ -455,12 +483,35 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
|
|
|
self._args, self._return_expr, self._explicit_return
|
|
|
)
|
|
|
|
|
|
+ def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
|
|
|
+ """Check if the function can be called with the given arguments.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ *args: The arguments to call the function with.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ True if the function can be called with the given arguments.
|
|
|
+ """
|
|
|
+ for i, (validator, arg) in enumerate(zip(self._validators, args)):
|
|
|
+ if not validator(arg):
|
|
|
+ arg_name = self._args.args[i] if i < len(self._args.args) else None
|
|
|
+ if arg_name is not None:
|
|
|
+ raise VarTypeError(
|
|
|
+ f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
|
|
|
+ )
|
|
|
+ raise VarTypeError(
|
|
|
+ f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
|
|
|
+ )
|
|
|
+ return self._validators[len(args) :]
|
|
|
+
|
|
|
@classmethod
|
|
|
def create(
|
|
|
cls,
|
|
|
args_names: Sequence[Union[str, DestructuredArg]],
|
|
|
return_expr: Var | Any,
|
|
|
rest: str | None = None,
|
|
|
+ validators: Sequence[Callable[[Any], bool]] = (),
|
|
|
+ function_name: str = "",
|
|
|
explicit_return: bool = False,
|
|
|
_var_type: GenericType = Callable,
|
|
|
_var_data: VarData | None = None,
|
|
@@ -471,6 +522,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
|
|
|
args_names: The names of the arguments.
|
|
|
return_expr: The return expression of the function.
|
|
|
rest: The name of the rest argument.
|
|
|
+ validators: The validators for the arguments.
|
|
|
+ function_name: The name of the function.
|
|
|
explicit_return: Whether to use explicit return syntax.
|
|
|
_var_data: Additional hooks and imports associated with the Var.
|
|
|
|
|
@@ -482,6 +535,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
|
|
|
_var_type=_var_type,
|
|
|
_var_data=_var_data,
|
|
|
_args=FunctionArgs(args=tuple(args_names), rest=rest),
|
|
|
+ _function_name=function_name,
|
|
|
+ _validators=tuple(validators),
|
|
|
_return_expr=return_expr,
|
|
|
_explicit_return=explicit_return,
|
|
|
)
|