Ver Fonte

fix silly mistakes

Khaleel Al-Adhami há 6 meses atrás
pai
commit
ebc81811c0
1 ficheiros alterados com 66 adições e 11 exclusões
  1. 66 11
      reflex/vars/function.py

+ 66 - 11
reflex/vars/function.py

@@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overlo
 from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
 
 from reflex.utils import format
+from reflex.utils.exceptions import VarTypeError
 from reflex.utils.types import GenericType
 
 from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
@@ -109,12 +110,22 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
         Returns:
             The partially applied function.
         """
-        self._pre_check(*args)
         if not args:
-            return ArgsFunctionOperation.create((), self)
+            return self
+        remaining_validators = self._pre_check(*args)
+        if self.__call__ is self.partial:
+            # if the default behavior is partial, we should return a new partial function
+            return ArgsFunctionOperationBuilder.create(
+                (),
+                VarOperationCall.create(self, *args, Var(_js_expr="...args")),
+                rest="args",
+                validators=remaining_validators,
+            )
         return ArgsFunctionOperation.create(
-            ("...args",),
+            (),
             VarOperationCall.create(self, *args, Var(_js_expr="...args")),
+            rest="args",
+            validators=remaining_validators,
         )
 
     @overload
@@ -187,7 +198,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
         self._pre_check(*args)
         return VarOperationCall.create(self, *args).guess_type()
 
-    def _pre_check(self, *args: Var | Any) -> bool:
+    def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
         """Check if the function can be called with the given arguments.
 
         Args:
@@ -196,7 +207,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
         Returns:
             True if the function can be called with the given arguments.
         """
-        return True
+        return tuple()
 
     __call__ = call
 
@@ -346,7 +357,8 @@ def format_args_function_operation(
     """
     arg_names_str = ", ".join(
         [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
-    ) + (f", ...{args.rest}" if args.rest else "")
+        + ([f"...{args.rest}"] if args.rest else [])
+    )
 
     return_expr_str = str(LiteralVar.create(return_expr))
 
@@ -371,6 +383,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         default_factory=tuple
     )
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
+    _function_name: str = dataclasses.field(default="")
     _explicit_return: bool = dataclasses.field(default=False)
 
     @cached_property_no_lock
@@ -384,7 +397,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             self._args, self._return_expr, self._explicit_return
         )
 
-    def _pre_check(self, *args: Var | Any) -> bool:
+    def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
         """Check if the function can be called with the given arguments.
 
         Args:
@@ -393,10 +406,17 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         Returns:
             True if the function can be called with the given arguments.
         """
-        return all(
-            validator(arg)
-            for validator, arg in zip(self._validators, args, strict=False)
-        )
+        for i, (validator, arg) in enumerate(zip(self._validators, args)):
+            if not validator(arg):
+                arg_name = self._args.args[i] if i < len(self._args.args) else None
+                if arg_name is not None:
+                    raise VarTypeError(
+                        f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
+                    )
+                raise VarTypeError(
+                    f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
+                )
+        return self._validators[len(args) :]
 
     @classmethod
     def create(
@@ -405,6 +425,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         return_expr: Var | Any,
         rest: str | None = None,
         validators: Sequence[Callable[[Any], bool]] = (),
+        function_name: str = "",
         explicit_return: bool = False,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
@@ -415,6 +436,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             args_names: The names of the arguments.
             return_expr: The return expression of the function.
             rest: The name of the rest argument.
+            validators: The validators for the arguments.
+            function_name: The name of the function.
             explicit_return: Whether to use explicit return syntax.
             _var_data: Additional hooks and imports associated with the Var.
 
@@ -426,6 +449,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             _var_type=_var_type,
             _var_data=_var_data,
             _args=FunctionArgs(args=tuple(args_names), rest=rest),
+            _function_name=function_name,
             _validators=tuple(validators),
             _return_expr=return_expr,
             _explicit_return=explicit_return,
@@ -441,7 +465,11 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
     """Base class for immutable function defined via arguments and return expression with the builder pattern."""
 
     _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
+    _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
+        default_factory=tuple
+    )
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
+    _function_name: str = dataclasses.field(default="")
     _explicit_return: bool = dataclasses.field(default=False)
 
     @cached_property_no_lock
@@ -455,12 +483,35 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
             self._args, self._return_expr, self._explicit_return
         )
 
+    def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
+        """Check if the function can be called with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            True if the function can be called with the given arguments.
+        """
+        for i, (validator, arg) in enumerate(zip(self._validators, args)):
+            if not validator(arg):
+                arg_name = self._args.args[i] if i < len(self._args.args) else None
+                if arg_name is not None:
+                    raise VarTypeError(
+                        f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
+                    )
+                raise VarTypeError(
+                    f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
+                )
+        return self._validators[len(args) :]
+
     @classmethod
     def create(
         cls,
         args_names: Sequence[Union[str, DestructuredArg]],
         return_expr: Var | Any,
         rest: str | None = None,
+        validators: Sequence[Callable[[Any], bool]] = (),
+        function_name: str = "",
         explicit_return: bool = False,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
@@ -471,6 +522,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
             args_names: The names of the arguments.
             return_expr: The return expression of the function.
             rest: The name of the rest argument.
+            validators: The validators for the arguments.
+            function_name: The name of the function.
             explicit_return: Whether to use explicit return syntax.
             _var_data: Additional hooks and imports associated with the Var.
 
@@ -482,6 +535,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
             _var_type=_var_type,
             _var_data=_var_data,
             _args=FunctionArgs(args=tuple(args_names), rest=rest),
+            _function_name=function_name,
+            _validators=tuple(validators),
             _return_expr=return_expr,
             _explicit_return=explicit_return,
         )