Răsfoiți Sursa

add validation

Khaleel Al-Adhami 6 luni în urmă
părinte
comite
05bd41c040
1 a modificat fișierele cu 32 adăugiri și 0 ștergeri
  1. 32 0
      reflex/vars/function.py

+ 32 - 0
reflex/vars/function.py

@@ -109,6 +109,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
         Returns:
             The partially applied function.
         """
+        self._pre_check(*args)
         if not args:
             return ArgsFunctionOperation.create((), self)
         return ArgsFunctionOperation.create(
@@ -183,8 +184,20 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
         Returns:
             The function call operation.
         """
+        self._pre_check(*args)
         return VarOperationCall.create(self, *args).guess_type()
 
+    def _pre_check(self, *args: Var | Any) -> bool:
+        """Check if the function can be called with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            True if the function can be called with the given arguments.
+        """
+        return True
+
     __call__ = call
 
 
@@ -354,6 +367,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
     """Base class for immutable function defined via arguments and return expression."""
 
     _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
+    _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
+        default_factory=tuple
+    )
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
     _explicit_return: bool = dataclasses.field(default=False)
 
@@ -368,12 +384,27 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             self._args, self._return_expr, self._explicit_return
         )
 
+    def _pre_check(self, *args: Var | Any) -> bool:
+        """Check if the function can be called with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            True if the function can be called with the given arguments.
+        """
+        return all(
+            validator(arg)
+            for validator, arg in zip(self._validators, args, strict=False)
+        )
+
     @classmethod
     def create(
         cls,
         args_names: Sequence[Union[str, DestructuredArg]],
         return_expr: Var | Any,
         rest: str | None = None,
+        validators: Sequence[Callable[[Any], bool]] = (),
         explicit_return: bool = False,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
@@ -395,6 +426,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             _var_type=_var_type,
             _var_data=_var_data,
             _args=FunctionArgs(args=tuple(args_names), rest=rest),
+            _validators=tuple(validators),
             _return_expr=return_expr,
             _explicit_return=explicit_return,
         )