Browse Source

implement type computers

Khaleel Al-Adhami 6 months ago
parent
commit
f4aa1f58c3
6 changed files with 699 additions and 317 deletions
  1. 325 115
      reflex/vars/base.py
  2. 155 77
      reflex/vars/function.py
  3. 62 49
      reflex/vars/number.py
  4. 31 12
      reflex/vars/object.py
  5. 124 62
      reflex/vars/sequence.py
  6. 2 2
      tests/units/test_var.py

+ 325 - 115
reflex/vars/base.py

@@ -14,7 +14,7 @@ import re
 import string
 import string
 import sys
 import sys
 import warnings
 import warnings
-from types import CodeType, FunctionType
+from types import CodeType, EllipsisType, FunctionType
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
     Any,
     Any,
@@ -26,7 +26,6 @@ from typing import (
     Iterable,
     Iterable,
     List,
     List,
     Literal,
     Literal,
-    NoReturn,
     Optional,
     Optional,
     Set,
     Set,
     Tuple,
     Tuple,
@@ -38,7 +37,14 @@ from typing import (
     overload,
     overload,
 )
 )
 
 
-from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
+from typing_extensions import (
+    ParamSpec,
+    Protocol,
+    TypeGuard,
+    deprecated,
+    get_type_hints,
+    override,
+)
 
 
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
@@ -69,6 +75,7 @@ from reflex.utils.types import (
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.state import BaseState
     from reflex.state import BaseState
 
 
+    from .function import ArgsFunctionOperation, ReflexCallable
     from .number import BooleanVar, NumberVar
     from .number import BooleanVar, NumberVar
     from .object import ObjectVar
     from .object import ObjectVar
     from .sequence import ArrayVar, StringVar
     from .sequence import ArrayVar, StringVar
@@ -79,6 +86,36 @@ OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
 
 
 warnings.filterwarnings("ignore", message="fields may not start with an underscore")
 warnings.filterwarnings("ignore", message="fields may not start with an underscore")
 
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class ReflexCallable(Protocol[P, R]):
+    """Protocol for a callable."""
+
+    __call__: Callable[P, R]
+
+
+def unwrap_reflex_callalbe(
+    callable_type: GenericType,
+) -> Tuple[Union[EllipsisType, Tuple[GenericType, ...]], GenericType]:
+    """Unwrap the ReflexCallable type.
+
+    Args:
+        callable_type: The ReflexCallable type to unwrap.
+
+    Returns:
+        The unwrapped ReflexCallable type.
+    """
+    if callable_type is ReflexCallable:
+        return Ellipsis, Any
+    if get_origin(callable_type) is not ReflexCallable:
+        return Ellipsis, Any
+    args = get_args(callable_type)
+    if not args or len(args) != 2:
+        return Ellipsis, Any
+    return args
+
 
 
 @dataclasses.dataclass(
 @dataclasses.dataclass(
     eq=False,
     eq=False,
@@ -409,9 +446,11 @@ class Var(Generic[VAR_TYPE]):
 
 
         if _var_data or _js_expr != self._js_expr:
         if _var_data or _js_expr != self._js_expr:
             self.__init__(
             self.__init__(
-                _js_expr=_js_expr,
-                _var_type=self._var_type,
-                _var_data=VarData.merge(self._var_data, _var_data),
+                **{
+                    **dataclasses.asdict(self),
+                    "_js_expr": _js_expr,
+                    "_var_data": VarData.merge(self._var_data, _var_data),
+                }
             )
             )
 
 
     def __hash__(self) -> int:
     def __hash__(self) -> int:
@@ -690,6 +729,12 @@ class Var(Generic[VAR_TYPE]):
     @overload
     @overload
     def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
     def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
 
 
+    @overload
+    def guess_type(self: Var[list] | Var[tuple] | Var[set]) -> ArrayVar: ...
+
+    @overload
+    def guess_type(self: Var[dict]) -> ObjectVar[dict]: ...
+
     @overload
     @overload
     def guess_type(self) -> Self: ...
     def guess_type(self) -> Self: ...
 
 
@@ -1413,71 +1458,94 @@ def get_python_literal(value: Union[LiteralVar, Any]) -> Any | None:
     return value
     return value
 
 
 
 
+def validate_arg(type_hint: GenericType) -> Callable[[Any], bool]:
+    """Create a validator for an argument.
+
+    Args:
+        type_hint: The type hint of the argument.
+
+    Returns:
+        The validator.
+    """
+
+    def validate(value: Any):
+        return True
+
+    return validate
+
+
 P = ParamSpec("P")
 P = ParamSpec("P")
 T = TypeVar("T")
 T = TypeVar("T")
+V1 = TypeVar("V1")
+V2 = TypeVar("V2")
+V3 = TypeVar("V3")
+V4 = TypeVar("V4")
+V5 = TypeVar("V5")
 
 
 
 
-# NoReturn is used to match CustomVarOperationReturn with no type hint.
-@overload
-def var_operation(
-    func: Callable[P, CustomVarOperationReturn[NoReturn]],
-) -> Callable[P, Var]: ...
-
+class TypeComputer(Protocol):
+    """A protocol for type computers."""
 
 
-@overload
-def var_operation(
-    func: Callable[P, CustomVarOperationReturn[bool]],
-) -> Callable[P, BooleanVar]: ...
+    def __call__(self, *args: Var) -> Tuple[GenericType, Union[TypeComputer, None]]:
+        """Compute the type of the operation.
 
 
+        Args:
+            *args: The arguments to compute the type of.
 
 
-NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float])
+        Returns:
+            The type of the operation.
+        """
+        ...
 
 
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[NUMBER_T]],
-) -> Callable[P, NumberVar[NUMBER_T]]: ...
+    func: Callable[[], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[], T]]: ...
 
 
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[str]],
-) -> Callable[P, StringVar]: ...
-
-
-LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set])
+    func: Callable[[Var[V1]], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[V1], T]]: ...
 
 
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[LIST_T]],
-) -> Callable[P, ArrayVar[LIST_T]]: ...
+    func: Callable[[Var[V1], Var[V2]], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[V1, V2], T]]: ...
 
 
 
 
-OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
+@overload
+def var_operation(
+    func: Callable[[Var[V1], Var[V2], Var[V3]], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3], T]]: ...
 
 
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]],
-) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
+    func: Callable[[Var[V1], Var[V2], Var[V3], Var[V4]], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4], T]]: ...
 
 
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[T]],
-) -> Callable[P, Var[T]]: ...
+    func: Callable[
+        [Var[V1], Var[V2], Var[V3], Var[V4], Var[V5]],
+        CustomVarOperationReturn[T],
+    ],
+) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4, V5], T]]: ...
 
 
 
 
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[T]],
-) -> Callable[P, Var[T]]:
+    func: Callable[..., CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation:
     """Decorator for creating a var operation.
     """Decorator for creating a var operation.
 
 
     Example:
     Example:
     ```python
     ```python
     @var_operation
     @var_operation
-    def add(a: NumberVar, b: NumberVar):
+    def add(a: Var[int], b: Var[int]):
         return custom_var_operation(f"{a} + {b}")
         return custom_var_operation(f"{a} + {b}")
     ```
     ```
 
 
@@ -1487,26 +1555,61 @@ def var_operation(
     Returns:
     Returns:
         The decorated function.
         The decorated function.
     """
     """
+    from .function import ArgsFunctionOperation, ReflexCallable
 
 
-    @functools.wraps(func)
-    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]:
-        func_args = list(inspect.signature(func).parameters)
-        args_vars = {
-            func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg)
-            for i, arg in enumerate(args)
-        }
-        kwargs_vars = {
-            key: LiteralVar.create(value) if not isinstance(value, Var) else value
-            for key, value in kwargs.items()
-        }
-
-        return CustomVarOperation.create(
-            name=func.__name__,
-            args=tuple(list(args_vars.items()) + list(kwargs_vars.items())),
-            return_var=func(*args_vars.values(), **kwargs_vars),  # type: ignore
-        ).guess_type()
+    func_name = func.__name__
 
 
-    return wrapper
+    func_arg_spec = inspect.getfullargspec(func)
+
+    if func_arg_spec.kwonlyargs:
+        raise TypeError(f"Function {func_name} cannot have keyword-only arguments.")
+    if func_arg_spec.varargs:
+        raise TypeError(f"Function {func_name} cannot have variable arguments.")
+
+    arg_names = func_arg_spec.args
+
+    type_hints = get_type_hints(func)
+
+    if not all(
+        (get_origin((type_hint := type_hints.get(arg_name, Any))) or type_hint) is Var
+        and len(get_args(type_hint)) <= 1
+        for arg_name in arg_names
+    ):
+        raise TypeError(
+            f"Function {func_name} must have type hints of the form `Var[Type]`."
+        )
+
+    args_with_type_hints = tuple(
+        (arg_name, (args[0] if (args := get_args(type_hints[arg_name])) else Any))
+        for arg_name in arg_names
+    )
+
+    arg_vars = tuple(
+        (
+            Var("_" + arg_name, _var_type=arg_python_type)
+            if not isinstance(arg_python_type, TypeVar)
+            else Var("_" + arg_name)
+        )
+        for arg_name, arg_python_type in args_with_type_hints
+    )
+
+    custom_operation_return = func(*arg_vars)
+
+    args_operation = ArgsFunctionOperation.create(
+        tuple(map(str, arg_vars)),
+        custom_operation_return,
+        validators=tuple(
+            validate_arg(type_hints.get(arg_name, Any)) for arg_name in arg_names
+        ),
+        function_name=func_name,
+        type_computer=custom_operation_return._type_computer,
+        _var_type=ReflexCallable[
+            tuple(arg_python_type for _, arg_python_type in args_with_type_hints),
+            custom_operation_return._var_type,
+        ],
+    )
+
+    return args_operation
 
 
 
 
 def figure_out_type(value: Any) -> types.GenericType:
 def figure_out_type(value: Any) -> types.GenericType:
@@ -1621,66 +1724,6 @@ class CachedVarOperation:
         )
         )
 
 
 
 
-def and_operation(a: Var | Any, b: Var | Any) -> Var:
-    """Perform a logical AND operation on two variables.
-
-    Args:
-        a: The first variable.
-        b: The second variable.
-
-    Returns:
-        The result of the logical AND operation.
-    """
-    return _and_operation(a, b)  # type: ignore
-
-
-@var_operation
-def _and_operation(a: Var, b: Var):
-    """Perform a logical AND operation on two variables.
-
-    Args:
-        a: The first variable.
-        b: The second variable.
-
-    Returns:
-        The result of the logical AND operation.
-    """
-    return var_operation_return(
-        js_expression=f"({a} && {b})",
-        var_type=unionize(a._var_type, b._var_type),
-    )
-
-
-def or_operation(a: Var | Any, b: Var | Any) -> Var:
-    """Perform a logical OR operation on two variables.
-
-    Args:
-        a: The first variable.
-        b: The second variable.
-
-    Returns:
-        The result of the logical OR operation.
-    """
-    return _or_operation(a, b)  # type: ignore
-
-
-@var_operation
-def _or_operation(a: Var, b: Var):
-    """Perform a logical OR operation on two variables.
-
-    Args:
-        a: The first variable.
-        b: The second variable.
-
-    Returns:
-        The result of the logical OR operation.
-    """
-    return var_operation_return(
-        js_expression=f"({a} || {b})",
-        var_type=unionize(a._var_type, b._var_type),
-    )
-
-
 @dataclasses.dataclass(
 @dataclasses.dataclass(
     eq=False,
     eq=False,
     frozen=True,
     frozen=True,
@@ -2289,14 +2332,22 @@ def computed_var(
 RETURN = TypeVar("RETURN")
 RETURN = TypeVar("RETURN")
 
 
 
 
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
 class CustomVarOperationReturn(Var[RETURN]):
 class CustomVarOperationReturn(Var[RETURN]):
     """Base class for custom var operations."""
     """Base class for custom var operations."""
 
 
+    _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
+
     @classmethod
     @classmethod
     def create(
     def create(
         cls,
         cls,
         js_expression: str,
         js_expression: str,
         _var_type: Type[RETURN] | None = None,
         _var_type: Type[RETURN] | None = None,
+        _type_computer: Optional[TypeComputer] = None,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
     ) -> CustomVarOperationReturn[RETURN]:
     ) -> CustomVarOperationReturn[RETURN]:
         """Create a CustomVarOperation.
         """Create a CustomVarOperation.
@@ -2304,6 +2355,7 @@ class CustomVarOperationReturn(Var[RETURN]):
         Args:
         Args:
             js_expression: The JavaScript expression to evaluate.
             js_expression: The JavaScript expression to evaluate.
             _var_type: The type of the var.
             _var_type: The type of the var.
+            _type_computer: A function to compute the type of the var given the arguments.
             _var_data: Additional hooks and imports associated with the Var.
             _var_data: Additional hooks and imports associated with the Var.
 
 
         Returns:
         Returns:
@@ -2312,6 +2364,7 @@ class CustomVarOperationReturn(Var[RETURN]):
         return CustomVarOperationReturn(
         return CustomVarOperationReturn(
             _js_expr=js_expression,
             _js_expr=js_expression,
             _var_type=_var_type or Any,
             _var_type=_var_type or Any,
+            _type_computer=_type_computer,
             _var_data=_var_data,
             _var_data=_var_data,
         )
         )
 
 
@@ -2319,6 +2372,7 @@ class CustomVarOperationReturn(Var[RETURN]):
 def var_operation_return(
 def var_operation_return(
     js_expression: str,
     js_expression: str,
     var_type: Type[RETURN] | None = None,
     var_type: Type[RETURN] | None = None,
+    type_computer: Optional[TypeComputer] = None,
     var_data: VarData | None = None,
     var_data: VarData | None = None,
 ) -> CustomVarOperationReturn[RETURN]:
 ) -> CustomVarOperationReturn[RETURN]:
     """Shortcut for creating a CustomVarOperationReturn.
     """Shortcut for creating a CustomVarOperationReturn.
@@ -2326,15 +2380,17 @@ def var_operation_return(
     Args:
     Args:
         js_expression: The JavaScript expression to evaluate.
         js_expression: The JavaScript expression to evaluate.
         var_type: The type of the var.
         var_type: The type of the var.
+        type_computer: A function to compute the type of the var given the arguments.
         var_data: Additional hooks and imports associated with the Var.
         var_data: Additional hooks and imports associated with the Var.
 
 
     Returns:
     Returns:
         The CustomVarOperationReturn.
         The CustomVarOperationReturn.
     """
     """
     return CustomVarOperationReturn.create(
     return CustomVarOperationReturn.create(
-        js_expression,
-        var_type,
-        var_data,
+        js_expression=js_expression,
+        _var_type=var_type,
+        _type_computer=type_computer,
+        _var_data=var_data,
     )
     )
 
 
 
 
@@ -2942,3 +2998,157 @@ def field(value: T) -> Field[T]:
         The Field.
         The Field.
     """
     """
     return value  # type: ignore
     return value  # type: ignore
+
+
+def and_operation(a: Var | Any, b: Var | Any) -> Var:
+    """Perform a logical AND operation on two variables.
+
+    Args:
+        a: The first variable.
+        b: The second variable.
+
+    Returns:
+        The result of the logical AND operation.
+    """
+    return _and_operation(a, b)  # type: ignore
+
+
+@var_operation
+def _and_operation(a: Var, b: Var):
+    """Perform a logical AND operation on two variables.
+
+    Args:
+        a: The first variable.
+        b: The second variable.
+
+    Returns:
+        The result of the logical AND operation.
+    """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (ReflexCallable[[Any, Any], Any], type_computer)
+        if len(args) == 1:
+            return (
+                ReflexCallable[[Any], Any],
+                functools.partial(type_computer, args[0]),
+            )
+        return (
+            ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
+            None,
+        )
+
+    return var_operation_return(
+        js_expression=f"({a} && {b})",
+        type_computer=type_computer,
+    )
+
+
+def or_operation(a: Var | Any, b: Var | Any) -> Var:
+    """Perform a logical OR operation on two variables.
+
+    Args:
+        a: The first variable.
+        b: The second variable.
+
+    Returns:
+        The result of the logical OR operation.
+    """
+    return _or_operation(a, b)  # type: ignore
+
+
+@var_operation
+def _or_operation(a: Var, b: Var):
+    """Perform a logical OR operation on two variables.
+
+    Args:
+        a: The first variable.
+        b: The second variable.
+
+    Returns:
+        The result of the logical OR operation.
+    """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (ReflexCallable[[Any, Any], Any], type_computer)
+        if len(args) == 1:
+            return (
+                ReflexCallable[[Any], Any],
+                functools.partial(type_computer, args[0]),
+            )
+        return (
+            ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
+            None,
+        )
+
+    return var_operation_return(
+        js_expression=f"({a} || {b})",
+        type_computer=type_computer,
+    )
+
+
+def passthrough_unary_type_computer(no_args: GenericType) -> TypeComputer:
+    """Create a type computer for unary operations.
+
+    Args:
+        no_args: The type to return when no arguments are provided.
+
+    Returns:
+        The type computer.
+    """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (no_args, type_computer)
+        return (ReflexCallable[[], args[0]._var_type], None)
+
+    return type_computer
+
+
+def unary_type_computer(
+    no_args: GenericType, computer: Callable[[Var], GenericType]
+) -> TypeComputer:
+    """Create a type computer for unary operations.
+
+    Args:
+        no_args: The type to return when no arguments are provided.
+        computer: The function to compute the type.
+
+    Returns:
+        The type computer.
+    """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (no_args, type_computer)
+        return (ReflexCallable[[], computer(args[0])], None)
+
+    return type_computer
+
+
+def nary_type_computer(
+    *types: GenericType, computer: Callable[..., GenericType]
+) -> TypeComputer:
+    """Create a type computer for n-ary operations.
+
+    Args:
+        types: The types to return when no arguments are provided.
+        computer: The function to compute the type.
+
+    Returns:
+        The type computer.
+    """
+
+    def type_computer(*args: Var):
+        if len(args) != len(types):
+            return (
+                ReflexCallable[[], types[len(args)]],
+                functools.partial(type_computer, *args),
+            )
+        return (
+            ReflexCallable[[], computer(args)],
+            None,
+        )
+
+    return type_computer

+ 155 - 77
reflex/vars/function.py

@@ -6,28 +6,31 @@ import dataclasses
 import sys
 import sys
 from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
 from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
 
 
-from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
+from typing_extensions import Concatenate, Generic, ParamSpec, TypeVar
 
 
 from reflex.utils import format
 from reflex.utils import format
 from reflex.utils.exceptions import VarTypeError
 from reflex.utils.exceptions import VarTypeError
 from reflex.utils.types import GenericType
 from reflex.utils.types import GenericType
 
 
-from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
+from .base import (
+    CachedVarOperation,
+    LiteralVar,
+    ReflexCallable,
+    TypeComputer,
+    Var,
+    VarData,
+    cached_property_no_lock,
+    unwrap_reflex_callalbe,
+)
 
 
 P = ParamSpec("P")
 P = ParamSpec("P")
+R = TypeVar("R")
 V1 = TypeVar("V1")
 V1 = TypeVar("V1")
 V2 = TypeVar("V2")
 V2 = TypeVar("V2")
 V3 = TypeVar("V3")
 V3 = TypeVar("V3")
 V4 = TypeVar("V4")
 V4 = TypeVar("V4")
 V5 = TypeVar("V5")
 V5 = TypeVar("V5")
 V6 = TypeVar("V6")
 V6 = TypeVar("V6")
-R = TypeVar("R")
-
-
-class ReflexCallable(Protocol[P, R]):
-    """Protocol for a callable."""
-
-    __call__: Callable[P, R]
 
 
 
 
 CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
 CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
@@ -112,20 +115,37 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
         """
         """
         if not args:
         if not args:
             return self
             return self
+
+        args = tuple(map(LiteralVar.create, args))
+
         remaining_validators = self._pre_check(*args)
         remaining_validators = self._pre_check(*args)
+
+        partial_types, type_computer = self._partial_type(*args)
+
         if self.__call__ is self.partial:
         if self.__call__ is self.partial:
             # if the default behavior is partial, we should return a new partial function
             # if the default behavior is partial, we should return a new partial function
             return ArgsFunctionOperationBuilder.create(
             return ArgsFunctionOperationBuilder.create(
                 (),
                 (),
-                VarOperationCall.create(self, *args, Var(_js_expr="...args")),
+                VarOperationCall.create(
+                    self,
+                    *args,
+                    Var(_js_expr="...args"),
+                    _var_type=self._return_type(*args),
+                ),
                 rest="args",
                 rest="args",
                 validators=remaining_validators,
                 validators=remaining_validators,
+                type_computer=type_computer,
+                _var_type=partial_types,
             )
             )
         return ArgsFunctionOperation.create(
         return ArgsFunctionOperation.create(
             (),
             (),
-            VarOperationCall.create(self, *args, Var(_js_expr="...args")),
+            VarOperationCall.create(
+                self, *args, Var(_js_expr="...args"), _var_type=self._return_type(*args)
+            ),
             rest="args",
             rest="args",
             validators=remaining_validators,
             validators=remaining_validators,
+            type_computer=type_computer,
+            _var_type=partial_types,
         )
         )
 
 
     @overload
     @overload
@@ -194,9 +214,56 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
 
 
         Returns:
         Returns:
             The function call operation.
             The function call operation.
+
+        Raises:
+            VarTypeError: If the number of arguments is invalid
         """
         """
+        arg_len = self._arg_len()
+        if arg_len is not None and len(args) != arg_len:
+            raise VarTypeError(f"Invalid number of arguments provided to {str(self)}")
+        args = tuple(map(LiteralVar.create, args))
         self._pre_check(*args)
         self._pre_check(*args)
-        return VarOperationCall.create(self, *args).guess_type()
+        return_type = self._return_type(*args)
+        return VarOperationCall.create(self, *args, _var_type=return_type).guess_type()
+
+    def _partial_type(
+        self, *args: Var | Any
+    ) -> Tuple[GenericType, Optional[TypeComputer]]:
+        """Override the type of the function call with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            The overridden type of the function call.
+        """
+        args_types, return_type = unwrap_reflex_callalbe(self._var_type)
+        if isinstance(args_types, tuple):
+            return ReflexCallable[[*args_types[len(args) :]], return_type], None
+        return ReflexCallable[..., return_type], None
+
+    def _arg_len(self) -> int | None:
+        """Get the number of arguments the function takes.
+
+        Returns:
+            The number of arguments the function takes.
+        """
+        args_types, _ = unwrap_reflex_callalbe(self._var_type)
+        if isinstance(args_types, tuple):
+            return len(args_types)
+        return None
+
+    def _return_type(self, *args: Var | Any) -> GenericType:
+        """Override the type of the function call with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            The overridden type of the function call.
+        """
+        partial_types, _ = self._partial_type(*args)
+        return unwrap_reflex_callalbe(partial_types)[1]
 
 
     def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
     def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
         """Check if the function can be called with the given arguments.
         """Check if the function can be called with the given arguments.
@@ -343,11 +410,12 @@ class FunctionArgs:
 
 
 
 
 def format_args_function_operation(
 def format_args_function_operation(
-    args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
+    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
 ) -> str:
 ) -> str:
     """Format an args function operation.
     """Format an args function operation.
 
 
     Args:
     Args:
+        self: The function operation.
         args: The function arguments.
         args: The function arguments.
         return_expr: The return expression.
         return_expr: The return expression.
         explicit_return: Whether to use explicit return syntax.
         explicit_return: Whether to use explicit return syntax.
@@ -356,26 +424,76 @@ def format_args_function_operation(
         The formatted args function operation.
         The formatted args function operation.
     """
     """
     arg_names_str = ", ".join(
     arg_names_str = ", ".join(
-        [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
-        + ([f"...{args.rest}"] if args.rest else [])
+        [
+            arg if isinstance(arg, str) else arg.to_javascript()
+            for arg in self._args.args
+        ]
+        + ([f"...{self._args.rest}"] if self._args.rest else [])
     )
     )
 
 
-    return_expr_str = str(LiteralVar.create(return_expr))
+    return_expr_str = str(LiteralVar.create(self._return_expr))
 
 
     # Wrap return expression in curly braces if explicit return syntax is used.
     # Wrap return expression in curly braces if explicit return syntax is used.
     return_expr_str_wrapped = (
     return_expr_str_wrapped = (
-        format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
+        format.wrap(return_expr_str, "{", "}")
+        if self._explicit_return
+        else return_expr_str
     )
     )
 
 
     return f"(({arg_names_str}) => {return_expr_str_wrapped})"
     return f"(({arg_names_str}) => {return_expr_str_wrapped})"
 
 
 
 
+def pre_check_args(
+    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, *args: Var | Any
+) -> Tuple[Callable[[Any], bool], ...]:
+    """Check if the function can be called with the given arguments.
+
+    Args:
+        self: The function operation.
+        *args: The arguments to call the function with.
+
+    Returns:
+        True if the function can be called with the given arguments.
+    """
+    for i, (validator, arg) in enumerate(zip(self._validators, args)):
+        if not validator(arg):
+            arg_name = self._args.args[i] if i < len(self._args.args) else None
+            if arg_name is not None:
+                raise VarTypeError(
+                    f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
+                )
+            raise VarTypeError(
+                f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
+            )
+    return self._validators[len(args) :]
+
+
+def figure_partial_type(
+    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
+    *args: Var | Any,
+) -> Tuple[GenericType, Optional[TypeComputer]]:
+    """Figure out the return type of the function.
+
+    Args:
+        self: The function operation.
+        *args: The arguments to call the function with.
+
+    Returns:
+        The return type of the function.
+    """
+    return (
+        self._type_computer(*args)
+        if self._type_computer is not None
+        else FunctionVar._partial_type(self, *args)
+    )
+
+
 @dataclasses.dataclass(
 @dataclasses.dataclass(
     eq=False,
     eq=False,
     frozen=True,
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
 )
-class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
+class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]):
     """Base class for immutable function defined via arguments and return expression."""
     """Base class for immutable function defined via arguments and return expression."""
 
 
     _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
     _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
@@ -384,39 +502,14 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
     )
     )
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
     _function_name: str = dataclasses.field(default="")
     _function_name: str = dataclasses.field(default="")
+    _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
     _explicit_return: bool = dataclasses.field(default=False)
     _explicit_return: bool = dataclasses.field(default=False)
 
 
-    @cached_property_no_lock
-    def _cached_var_name(self) -> str:
-        """The name of the var.
+    _cached_var_name = cached_property_no_lock(format_args_function_operation)
 
 
-        Returns:
-            The name of the var.
-        """
-        return format_args_function_operation(
-            self._args, self._return_expr, self._explicit_return
-        )
-
-    def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
-        """Check if the function can be called with the given arguments.
-
-        Args:
-            *args: The arguments to call the function with.
+    _pre_check = pre_check_args
 
 
-        Returns:
-            True if the function can be called with the given arguments.
-        """
-        for i, (validator, arg) in enumerate(zip(self._validators, args)):
-            if not validator(arg):
-                arg_name = self._args.args[i] if i < len(self._args.args) else None
-                if arg_name is not None:
-                    raise VarTypeError(
-                        f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
-                    )
-                raise VarTypeError(
-                    f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
-                )
-        return self._validators[len(args) :]
+    _partial_type = figure_partial_type
 
 
     @classmethod
     @classmethod
     def create(
     def create(
@@ -427,6 +520,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         validators: Sequence[Callable[[Any], bool]] = (),
         validators: Sequence[Callable[[Any], bool]] = (),
         function_name: str = "",
         function_name: str = "",
         explicit_return: bool = False,
         explicit_return: bool = False,
+        type_computer: Optional[TypeComputer] = None,
         _var_type: GenericType = Callable,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
     ):
     ):
@@ -439,6 +533,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             validators: The validators for the arguments.
             validators: The validators for the arguments.
             function_name: The name of the function.
             function_name: The name of the function.
             explicit_return: Whether to use explicit return syntax.
             explicit_return: Whether to use explicit return syntax.
+            type_computer: A function to compute the return type.
+            _var_type: The type of the var.
             _var_data: Additional hooks and imports associated with the Var.
             _var_data: Additional hooks and imports associated with the Var.
 
 
         Returns:
         Returns:
@@ -453,6 +549,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             _validators=tuple(validators),
             _validators=tuple(validators),
             _return_expr=return_expr,
             _return_expr=return_expr,
             _explicit_return=explicit_return,
             _explicit_return=explicit_return,
+            _type_computer=type_computer,
         )
         )
 
 
 
 
@@ -461,7 +558,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
     frozen=True,
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
 )
-class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
+class ArgsFunctionOperationBuilder(
+    CachedVarOperation, BuilderFunctionVar[CALLABLE_TYPE]
+):
     """Base class for immutable function defined via arguments and return expression with the builder pattern."""
     """Base class for immutable function defined via arguments and return expression with the builder pattern."""
 
 
     _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
     _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
@@ -470,39 +569,14 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
     )
     )
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
     _function_name: str = dataclasses.field(default="")
     _function_name: str = dataclasses.field(default="")
+    _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
     _explicit_return: bool = dataclasses.field(default=False)
     _explicit_return: bool = dataclasses.field(default=False)
 
 
-    @cached_property_no_lock
-    def _cached_var_name(self) -> str:
-        """The name of the var.
+    _cached_var_name = cached_property_no_lock(format_args_function_operation)
 
 
-        Returns:
-            The name of the var.
-        """
-        return format_args_function_operation(
-            self._args, self._return_expr, self._explicit_return
-        )
+    _pre_check = pre_check_args
 
 
-    def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
-        """Check if the function can be called with the given arguments.
-
-        Args:
-            *args: The arguments to call the function with.
-
-        Returns:
-            True if the function can be called with the given arguments.
-        """
-        for i, (validator, arg) in enumerate(zip(self._validators, args)):
-            if not validator(arg):
-                arg_name = self._args.args[i] if i < len(self._args.args) else None
-                if arg_name is not None:
-                    raise VarTypeError(
-                        f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
-                    )
-                raise VarTypeError(
-                    f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
-                )
-        return self._validators[len(args) :]
+    _partial_type = figure_partial_type
 
 
     @classmethod
     @classmethod
     def create(
     def create(
@@ -513,6 +587,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
         validators: Sequence[Callable[[Any], bool]] = (),
         validators: Sequence[Callable[[Any], bool]] = (),
         function_name: str = "",
         function_name: str = "",
         explicit_return: bool = False,
         explicit_return: bool = False,
+        type_computer: Optional[TypeComputer] = None,
         _var_type: GenericType = Callable,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
     ):
     ):
@@ -525,6 +600,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
             validators: The validators for the arguments.
             validators: The validators for the arguments.
             function_name: The name of the function.
             function_name: The name of the function.
             explicit_return: Whether to use explicit return syntax.
             explicit_return: Whether to use explicit return syntax.
+            type_computer: A function to compute the return type.
+            _var_type: The type of the var.
             _var_data: Additional hooks and imports associated with the Var.
             _var_data: Additional hooks and imports associated with the Var.
 
 
         Returns:
         Returns:
@@ -539,6 +616,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
             _validators=tuple(validators),
             _validators=tuple(validators),
             _return_expr=return_expr,
             _return_expr=return_expr,
             _explicit_return=explicit_return,
             _explicit_return=explicit_return,
+            _type_computer=type_computer,
         )
         )
 
 
 
 

+ 62 - 49
reflex/vars/number.py

@@ -3,19 +3,11 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
+import functools
 import json
 import json
 import math
 import math
 import sys
 import sys
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    NoReturn,
-    Type,
-    TypeVar,
-    Union,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, Callable, NoReturn, TypeVar, Union, overload
 
 
 from reflex.constants.base import Dirs
 from reflex.constants.base import Dirs
 from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
 from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
@@ -25,8 +17,11 @@ from reflex.utils.types import is_optional
 from .base import (
 from .base import (
     CustomVarOperationReturn,
     CustomVarOperationReturn,
     LiteralVar,
     LiteralVar,
+    ReflexCallable,
     Var,
     Var,
     VarData,
     VarData,
+    nary_type_computer,
+    passthrough_unary_type_computer,
     unionize,
     unionize,
     var_operation,
     var_operation,
     var_operation_return,
     var_operation_return,
@@ -544,8 +539,8 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
 
 
 
 
 def binary_number_operation(
 def binary_number_operation(
-    func: Callable[[NumberVar, NumberVar], str],
-) -> Callable[[number_types, number_types], NumberVar]:
+    func: Callable[[Var[int | float], Var[int | float]], str],
+):
     """Decorator to create a binary number operation.
     """Decorator to create a binary number operation.
 
 
     Args:
     Args:
@@ -555,30 +550,37 @@ def binary_number_operation(
         The binary number operation.
         The binary number operation.
     """
     """
 
 
-    @var_operation
-    def operation(lhs: NumberVar, rhs: NumberVar):
+    def operation(
+        lhs: Var[int | float], rhs: Var[int | float]
+    ) -> CustomVarOperationReturn[int | float]:
+        def type_computer(*args: Var):
+            if not args:
+                return (
+                    ReflexCallable[[int | float, int | float], int | float],
+                    type_computer,
+                )
+            if len(args) == 1:
+                return (
+                    ReflexCallable[[int | float], int | float],
+                    functools.partial(type_computer, args[0]),
+                )
+            return (
+                ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
+                None,
+            )
+
         return var_operation_return(
         return var_operation_return(
             js_expression=func(lhs, rhs),
             js_expression=func(lhs, rhs),
-            var_type=unionize(lhs._var_type, rhs._var_type),
+            type_computer=type_computer,
         )
         )
 
 
-    def wrapper(lhs: number_types, rhs: number_types) -> NumberVar:
-        """Create the binary number operation.
-
-        Args:
-            lhs: The first number.
-            rhs: The second number.
-
-        Returns:
-            The binary number operation.
-        """
-        return operation(lhs, rhs)  # type: ignore
+    operation.__name__ = func.__name__
 
 
-    return wrapper
+    return var_operation(operation)
 
 
 
 
 @binary_number_operation
 @binary_number_operation
-def number_add_operation(lhs: NumberVar, rhs: NumberVar):
+def number_add_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Add two numbers.
     """Add two numbers.
 
 
     Args:
     Args:
@@ -592,7 +594,7 @@ def number_add_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 
 
 @binary_number_operation
 @binary_number_operation
-def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
+def number_subtract_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Subtract two numbers.
     """Subtract two numbers.
 
 
     Args:
     Args:
@@ -605,8 +607,15 @@ def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
     return f"({lhs} - {rhs})"
     return f"({lhs} - {rhs})"
 
 
 
 
+unary_operation_type_computer = passthrough_unary_type_computer(
+    ReflexCallable[[int | float], int | float]
+)
+
+
 @var_operation
 @var_operation
-def number_abs_operation(value: NumberVar):
+def number_abs_operation(
+    value: Var[int | float],
+) -> CustomVarOperationReturn[int | float]:
     """Get the absolute value of the number.
     """Get the absolute value of the number.
 
 
     Args:
     Args:
@@ -616,12 +625,12 @@ def number_abs_operation(value: NumberVar):
         The number absolute operation.
         The number absolute operation.
     """
     """
     return var_operation_return(
     return var_operation_return(
-        js_expression=f"Math.abs({value})", var_type=value._var_type
+        js_expression=f"Math.abs({value})", type_computer=unary_operation_type_computer
     )
     )
 
 
 
 
 @binary_number_operation
 @binary_number_operation
-def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
+def number_multiply_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Multiply two numbers.
     """Multiply two numbers.
 
 
     Args:
     Args:
@@ -636,7 +645,7 @@ def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 @var_operation
 @var_operation
 def number_negate_operation(
 def number_negate_operation(
-    value: NumberVar[NUMBER_T],
+    value: Var[NUMBER_T],
 ) -> CustomVarOperationReturn[NUMBER_T]:
 ) -> CustomVarOperationReturn[NUMBER_T]:
     """Negate the number.
     """Negate the number.
 
 
@@ -646,11 +655,13 @@ def number_negate_operation(
     Returns:
     Returns:
         The number negation operation.
         The number negation operation.
     """
     """
-    return var_operation_return(js_expression=f"-({value})", var_type=value._var_type)
+    return var_operation_return(
+        js_expression=f"-({value})", type_computer=unary_operation_type_computer
+    )
 
 
 
 
 @binary_number_operation
 @binary_number_operation
-def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
+def number_true_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Divide two numbers.
     """Divide two numbers.
 
 
     Args:
     Args:
@@ -664,7 +675,7 @@ def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 
 
 @binary_number_operation
 @binary_number_operation
-def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
+def number_floor_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Floor divide two numbers.
     """Floor divide two numbers.
 
 
     Args:
     Args:
@@ -678,7 +689,7 @@ def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 
 
 @binary_number_operation
 @binary_number_operation
-def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
+def number_modulo_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Modulo two numbers.
     """Modulo two numbers.
 
 
     Args:
     Args:
@@ -692,7 +703,7 @@ def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 
 
 @binary_number_operation
 @binary_number_operation
-def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
+def number_exponent_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Exponentiate two numbers.
     """Exponentiate two numbers.
 
 
     Args:
     Args:
@@ -706,7 +717,7 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 
 
 @var_operation
 @var_operation
-def number_round_operation(value: NumberVar):
+def number_round_operation(value: Var[int | float]):
     """Round the number.
     """Round the number.
 
 
     Args:
     Args:
@@ -719,7 +730,7 @@ def number_round_operation(value: NumberVar):
 
 
 
 
 @var_operation
 @var_operation
-def number_ceil_operation(value: NumberVar):
+def number_ceil_operation(value: Var[int | float]):
     """Ceil the number.
     """Ceil the number.
 
 
     Args:
     Args:
@@ -732,7 +743,7 @@ def number_ceil_operation(value: NumberVar):
 
 
 
 
 @var_operation
 @var_operation
-def number_floor_operation(value: NumberVar):
+def number_floor_operation(value: Var[int | float]):
     """Floor the number.
     """Floor the number.
 
 
     Args:
     Args:
@@ -745,7 +756,7 @@ def number_floor_operation(value: NumberVar):
 
 
 
 
 @var_operation
 @var_operation
-def number_trunc_operation(value: NumberVar):
+def number_trunc_operation(value: Var[int | float]):
     """Trunc the number.
     """Trunc the number.
 
 
     Args:
     Args:
@@ -838,7 +849,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
 
 
 
 
 @var_operation
 @var_operation
-def boolean_to_number_operation(value: BooleanVar):
+def boolean_to_number_operation(value: Var[bool]):
     """Convert the boolean to a number.
     """Convert the boolean to a number.
 
 
     Args:
     Args:
@@ -969,7 +980,7 @@ def not_equal_operation(lhs: Var, rhs: Var):
 
 
 
 
 @var_operation
 @var_operation
-def boolean_not_operation(value: BooleanVar):
+def boolean_not_operation(value: Var[bool]):
     """Boolean NOT the boolean.
     """Boolean NOT the boolean.
 
 
     Args:
     Args:
@@ -1117,7 +1128,7 @@ U = TypeVar("U")
 
 
 @var_operation
 @var_operation
 def ternary_operation(
 def ternary_operation(
-    condition: BooleanVar, if_true: Var[T], if_false: Var[U]
+    condition: Var[bool], if_true: Var[T], if_false: Var[U]
 ) -> CustomVarOperationReturn[Union[T, U]]:
 ) -> CustomVarOperationReturn[Union[T, U]]:
     """Create a ternary operation.
     """Create a ternary operation.
 
 
@@ -1129,12 +1140,14 @@ def ternary_operation(
     Returns:
     Returns:
         The ternary operation.
         The ternary operation.
     """
     """
-    type_value: Union[Type[T], Type[U]] = unionize(
-        if_true._var_type, if_false._var_type
-    )
     value: CustomVarOperationReturn[Union[T, U]] = var_operation_return(
     value: CustomVarOperationReturn[Union[T, U]] = var_operation_return(
         js_expression=f"({condition} ? {if_true} : {if_false})",
         js_expression=f"({condition} ? {if_true} : {if_false})",
-        var_type=type_value,
+        type_computer=nary_type_computer(
+            ReflexCallable[[bool, Any, Any], Any],
+            ReflexCallable[[Any, Any], Any],
+            ReflexCallable[[Any], Any],
+            computer=lambda args: unionize(args[1]._var_type, args[2]._var_type),
+        ),
     )
     )
     return value
     return value
 
 

+ 31 - 12
reflex/vars/object.py

@@ -21,15 +21,23 @@ from typing import (
 
 
 from reflex.utils import types
 from reflex.utils import types
 from reflex.utils.exceptions import VarAttributeError
 from reflex.utils.exceptions import VarAttributeError
-from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
+from reflex.utils.types import (
+    GenericType,
+    get_attribute_access_type,
+    get_origin,
+    unionize,
+)
 
 
 from .base import (
 from .base import (
     CachedVarOperation,
     CachedVarOperation,
     LiteralVar,
     LiteralVar,
+    ReflexCallable,
     Var,
     Var,
     VarData,
     VarData,
     cached_property_no_lock,
     cached_property_no_lock,
     figure_out_type,
     figure_out_type,
+    nary_type_computer,
+    unary_type_computer,
     var_operation,
     var_operation,
     var_operation_return,
     var_operation_return,
 )
 )
@@ -406,7 +414,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
 
 
 
 
 @var_operation
 @var_operation
-def object_keys_operation(value: ObjectVar):
+def object_keys_operation(value: Var):
     """Get the keys of an object.
     """Get the keys of an object.
 
 
     Args:
     Args:
@@ -422,7 +430,7 @@ def object_keys_operation(value: ObjectVar):
 
 
 
 
 @var_operation
 @var_operation
-def object_values_operation(value: ObjectVar):
+def object_values_operation(value: Var):
     """Get the values of an object.
     """Get the values of an object.
 
 
     Args:
     Args:
@@ -433,12 +441,15 @@ def object_values_operation(value: ObjectVar):
     """
     """
     return var_operation_return(
     return var_operation_return(
         js_expression=f"Object.values({value})",
         js_expression=f"Object.values({value})",
-        var_type=List[value._value_type()],
+        type_computer=unary_type_computer(
+            ReflexCallable[[Any], List[Any]],
+            lambda x: List[x.to(ObjectVar)._value_type()],
+        ),
     )
     )
 
 
 
 
 @var_operation
 @var_operation
-def object_entries_operation(value: ObjectVar):
+def object_entries_operation(value: Var):
     """Get the entries of an object.
     """Get the entries of an object.
 
 
     Args:
     Args:
@@ -447,14 +458,18 @@ def object_entries_operation(value: ObjectVar):
     Returns:
     Returns:
         The entries of the object.
         The entries of the object.
     """
     """
+    value = value.to(ObjectVar)
     return var_operation_return(
     return var_operation_return(
         js_expression=f"Object.entries({value})",
         js_expression=f"Object.entries({value})",
-        var_type=List[Tuple[str, value._value_type()]],
+        type_computer=unary_type_computer(
+            ReflexCallable[[Any], List[Tuple[str, Any]]],
+            lambda x: List[Tuple[str, x.to(ObjectVar)._value_type()]],
+        ),
     )
     )
 
 
 
 
 @var_operation
 @var_operation
-def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
+def object_merge_operation(lhs: Var, rhs: Var):
     """Merge two objects.
     """Merge two objects.
 
 
     Args:
     Args:
@@ -466,10 +481,14 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
     """
     """
     return var_operation_return(
     return var_operation_return(
         js_expression=f"({{...{lhs}, ...{rhs}}})",
         js_expression=f"({{...{lhs}, ...{rhs}}})",
-        var_type=Dict[
-            Union[lhs._key_type(), rhs._key_type()],
-            Union[lhs._value_type(), rhs._value_type()],
-        ],
+        type_computer=nary_type_computer(
+            ReflexCallable[[Any, Any], Dict[Any, Any]],
+            ReflexCallable[[Any], Dict[Any, Any]],
+            computer=lambda args: Dict[
+                unionize(*[arg.to(ObjectVar)._key_type() for arg in args]),
+                unionize(*[arg.to(ObjectVar)._value_type() for arg in args]),
+            ],
+        ),
     )
     )
 
 
 
 
@@ -526,7 +545,7 @@ class ObjectItemOperation(CachedVarOperation, Var):
 
 
 
 
 @var_operation
 @var_operation
-def object_has_own_property_operation(object: ObjectVar, key: Var):
+def object_has_own_property_operation(object: Var, key: Var):
     """Check if an object has a key.
     """Check if an object has a key.
 
 
     Args:
     Args:

+ 124 - 62
reflex/vars/sequence.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
+import functools
 import inspect
 import inspect
 import json
 import json
 import re
 import re
@@ -34,6 +35,7 @@ from .base import (
     CachedVarOperation,
     CachedVarOperation,
     CustomVarOperationReturn,
     CustomVarOperationReturn,
     LiteralVar,
     LiteralVar,
+    ReflexCallable,
     Var,
     Var,
     VarData,
     VarData,
     _global_vars,
     _global_vars,
@@ -41,7 +43,10 @@ from .base import (
     figure_out_type,
     figure_out_type,
     get_python_literal,
     get_python_literal,
     get_unique_variable_name,
     get_unique_variable_name,
+    nary_type_computer,
+    passthrough_unary_type_computer,
     unionize,
     unionize,
+    unwrap_reflex_callalbe,
     var_operation,
     var_operation,
     var_operation_return,
     var_operation_return,
 )
 )
@@ -353,7 +358,7 @@ class StringVar(Var[STRING_TYPE], python_types=str):
 
 
 
 
 @var_operation
 @var_operation
-def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
+def string_lt_operation(lhs: Var[str], rhs: Var[str]):
     """Check if a string is less than another string.
     """Check if a string is less than another string.
 
 
     Args:
     Args:
@@ -367,7 +372,7 @@ def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
 
 
 
 
 @var_operation
 @var_operation
-def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
+def string_gt_operation(lhs: Var[str], rhs: Var[str]):
     """Check if a string is greater than another string.
     """Check if a string is greater than another string.
 
 
     Args:
     Args:
@@ -381,7 +386,7 @@ def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
 
 
 
 
 @var_operation
 @var_operation
-def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
+def string_le_operation(lhs: Var[str], rhs: Var[str]):
     """Check if a string is less than or equal to another string.
     """Check if a string is less than or equal to another string.
 
 
     Args:
     Args:
@@ -395,7 +400,7 @@ def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
 
 
 
 
 @var_operation
 @var_operation
-def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
+def string_ge_operation(lhs: Var[str], rhs: Var[str]):
     """Check if a string is greater than or equal to another string.
     """Check if a string is greater than or equal to another string.
 
 
     Args:
     Args:
@@ -409,7 +414,7 @@ def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
 
 
 
 
 @var_operation
 @var_operation
-def string_lower_operation(string: StringVar[Any]):
+def string_lower_operation(string: Var[str]):
     """Convert a string to lowercase.
     """Convert a string to lowercase.
 
 
     Args:
     Args:
@@ -422,7 +427,7 @@ def string_lower_operation(string: StringVar[Any]):
 
 
 
 
 @var_operation
 @var_operation
-def string_upper_operation(string: StringVar[Any]):
+def string_upper_operation(string: Var[str]):
     """Convert a string to uppercase.
     """Convert a string to uppercase.
 
 
     Args:
     Args:
@@ -435,7 +440,7 @@ def string_upper_operation(string: StringVar[Any]):
 
 
 
 
 @var_operation
 @var_operation
-def string_strip_operation(string: StringVar[Any]):
+def string_strip_operation(string: Var[str]):
     """Strip a string.
     """Strip a string.
 
 
     Args:
     Args:
@@ -449,7 +454,7 @@ def string_strip_operation(string: StringVar[Any]):
 
 
 @var_operation
 @var_operation
 def string_contains_field_operation(
 def string_contains_field_operation(
-    haystack: StringVar[Any], needle: StringVar[Any] | str, field: StringVar[Any] | str
+    haystack: Var[str], needle: Var[str], field: Var[str]
 ):
 ):
     """Check if a string contains another string.
     """Check if a string contains another string.
 
 
@@ -468,7 +473,7 @@ def string_contains_field_operation(
 
 
 
 
 @var_operation
 @var_operation
-def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | str):
+def string_contains_operation(haystack: Var[str], needle: Var[str]):
     """Check if a string contains another string.
     """Check if a string contains another string.
 
 
     Args:
     Args:
@@ -484,9 +489,7 @@ def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] |
 
 
 
 
 @var_operation
 @var_operation
-def string_starts_with_operation(
-    full_string: StringVar[Any], prefix: StringVar[Any] | str
-):
+def string_starts_with_operation(full_string: Var[str], prefix: Var[str]):
     """Check if a string starts with a prefix.
     """Check if a string starts with a prefix.
 
 
     Args:
     Args:
@@ -502,7 +505,7 @@ def string_starts_with_operation(
 
 
 
 
 @var_operation
 @var_operation
-def string_item_operation(string: StringVar[Any], index: NumberVar | int):
+def string_item_operation(string: Var[str], index: Var[int]):
     """Get an item from a string.
     """Get an item from a string.
 
 
     Args:
     Args:
@@ -515,23 +518,9 @@ def string_item_operation(string: StringVar[Any], index: NumberVar | int):
     return var_operation_return(js_expression=f"{string}.at({index})", var_type=str)
     return var_operation_return(js_expression=f"{string}.at({index})", var_type=str)
 
 
 
 
-@var_operation
-def array_join_operation(array: ArrayVar, sep: StringVar[Any] | str = ""):
-    """Join the elements of an array.
-
-    Args:
-        array: The array.
-        sep: The separator.
-
-    Returns:
-        The joined elements.
-    """
-    return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str)
-
-
 @var_operation
 @var_operation
 def string_replace_operation(
 def string_replace_operation(
-    string: StringVar, search_value: StringVar | str, new_value: StringVar | str
+    string: Var[str], search_value: Var[str], new_value: Var[str]
 ):
 ):
     """Replace a string with a value.
     """Replace a string with a value.
 
 
@@ -1046,7 +1035,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
         Returns:
         Returns:
             The array pluck operation.
             The array pluck operation.
         """
         """
-        return array_pluck_operation(self, field)
+        return array_pluck_operation(self, field).guess_type()
 
 
     @overload
     @overload
     def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: ...
     def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: ...
@@ -1300,7 +1289,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
 
 
 
 
 @var_operation
 @var_operation
-def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""):
+def string_split_operation(string: Var[str], sep: Var[str]):
     """Split a string.
     """Split a string.
 
 
     Args:
     Args:
@@ -1394,9 +1383,9 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar):
 
 
 @var_operation
 @var_operation
 def array_pluck_operation(
 def array_pluck_operation(
-    array: ArrayVar[ARRAY_VAR_TYPE],
-    field: StringVar | str,
-) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
+    array: Var[ARRAY_VAR_TYPE],
+    field: Var[str],
+) -> CustomVarOperationReturn[List]:
     """Pluck a field from an array of objects.
     """Pluck a field from an array of objects.
 
 
     Args:
     Args:
@@ -1408,13 +1397,27 @@ def array_pluck_operation(
     """
     """
     return var_operation_return(
     return var_operation_return(
         js_expression=f"{array}.map(e=>e?.[{field}])",
         js_expression=f"{array}.map(e=>e?.[{field}])",
-        var_type=array._var_type,
+        var_type=List[Any],
     )
     )
 
 
 
 
+@var_operation
+def array_join_operation(array: Var[ARRAY_VAR_TYPE], sep: Var[str]):
+    """Join the elements of an array.
+
+    Args:
+        array: The array.
+        sep: The separator.
+
+    Returns:
+        The joined elements.
+    """
+    return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str)
+
+
 @var_operation
 @var_operation
 def array_reverse_operation(
 def array_reverse_operation(
-    array: ArrayVar[ARRAY_VAR_TYPE],
+    array: Var[ARRAY_VAR_TYPE],
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
     """Reverse an array.
     """Reverse an array.
 
 
@@ -1426,12 +1429,12 @@ def array_reverse_operation(
     """
     """
     return var_operation_return(
     return var_operation_return(
         js_expression=f"{array}.slice().reverse()",
         js_expression=f"{array}.slice().reverse()",
-        var_type=array._var_type,
+        type_computer=passthrough_unary_type_computer(ReflexCallable[[List], List]),
     )
     )
 
 
 
 
 @var_operation
 @var_operation
-def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
+def array_lt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
     """Check if an array is less than another array.
     """Check if an array is less than another array.
 
 
     Args:
     Args:
@@ -1445,7 +1448,7 @@ def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
 
 
 
 
 @var_operation
 @var_operation
-def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
+def array_gt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
     """Check if an array is greater than another array.
     """Check if an array is greater than another array.
 
 
     Args:
     Args:
@@ -1459,7 +1462,7 @@ def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
 
 
 
 
 @var_operation
 @var_operation
-def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
+def array_le_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
     """Check if an array is less than or equal to another array.
     """Check if an array is less than or equal to another array.
 
 
     Args:
     Args:
@@ -1473,7 +1476,7 @@ def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
 
 
 
 
 @var_operation
 @var_operation
-def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
+def array_ge_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
     """Check if an array is greater than or equal to another array.
     """Check if an array is greater than or equal to another array.
 
 
     Args:
     Args:
@@ -1487,7 +1490,7 @@ def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
 
 
 
 
 @var_operation
 @var_operation
-def array_length_operation(array: ArrayVar):
+def array_length_operation(array: Var[ARRAY_VAR_TYPE]):
     """Get the length of an array.
     """Get the length of an array.
 
 
     Args:
     Args:
@@ -1517,7 +1520,7 @@ def is_tuple_type(t: GenericType) -> bool:
 
 
 
 
 @var_operation
 @var_operation
-def array_item_operation(array: ArrayVar, index: NumberVar | int):
+def array_item_operation(array: Var[ARRAY_VAR_TYPE], index: Var[int]):
     """Get an item from an array.
     """Get an item from an array.
 
 
     Args:
     Args:
@@ -1527,23 +1530,45 @@ def array_item_operation(array: ArrayVar, index: NumberVar | int):
     Returns:
     Returns:
         The item from the array.
         The item from the array.
     """
     """
-    args = typing.get_args(array._var_type)
-    if args and isinstance(index, LiteralNumberVar) and is_tuple_type(array._var_type):
-        index_value = int(index._var_value)
-        element_type = args[index_value % len(args)]
-    else:
-        element_type = unionize(*args)
+
+    def type_computer(*args):
+        if len(args) == 0:
+            return (
+                ReflexCallable[[List[Any], int], Any],
+                functools.partial(type_computer, *args),
+            )
+
+        array = args[0]
+        array_args = typing.get_args(array._var_type)
+
+        if len(args) == 1:
+            return (
+                ReflexCallable[[int], unionize(*array_args)],
+                functools.partial(type_computer, *args),
+            )
+
+        index = args[1]
+
+        if (
+            array_args
+            and isinstance(index, LiteralNumberVar)
+            and is_tuple_type(array._var_type)
+        ):
+            index_value = int(index._var_value)
+            element_type = array_args[index_value % len(array_args)]
+        else:
+            element_type = unionize(*array_args)
+
+        return (ReflexCallable[[], element_type], None)
 
 
     return var_operation_return(
     return var_operation_return(
         js_expression=f"{str(array)}.at({str(index)})",
         js_expression=f"{str(array)}.at({str(index)})",
-        var_type=element_type,
+        type_computer=type_computer,
     )
     )
 
 
 
 
 @var_operation
 @var_operation
-def array_range_operation(
-    start: NumberVar | int, stop: NumberVar | int, step: NumberVar | int
-):
+def array_range_operation(start: Var[int], stop: Var[int], step: Var[int]):
     """Create a range of numbers.
     """Create a range of numbers.
 
 
     Args:
     Args:
@@ -1562,7 +1587,7 @@ def array_range_operation(
 
 
 @var_operation
 @var_operation
 def array_contains_field_operation(
 def array_contains_field_operation(
-    haystack: ArrayVar, needle: Any | Var, field: StringVar | str
+    haystack: Var[ARRAY_VAR_TYPE], needle: Var, field: Var[str]
 ):
 ):
     """Check if an array contains an element.
     """Check if an array contains an element.
 
 
@@ -1581,7 +1606,7 @@ def array_contains_field_operation(
 
 
 
 
 @var_operation
 @var_operation
-def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
+def array_contains_operation(haystack: Var[ARRAY_VAR_TYPE], needle: Var):
     """Check if an array contains an element.
     """Check if an array contains an element.
 
 
     Args:
     Args:
@@ -1599,7 +1624,7 @@ def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
 
 
 @var_operation
 @var_operation
 def repeat_array_operation(
 def repeat_array_operation(
-    array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int
+    array: Var[ARRAY_VAR_TYPE], count: Var[int]
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
     """Repeat an array a number of times.
     """Repeat an array a number of times.
 
 
@@ -1610,20 +1635,34 @@ def repeat_array_operation(
     Returns:
     Returns:
         The repeated array.
         The repeated array.
     """
     """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (
+                ReflexCallable[[List[Any], int], List[Any]],
+                type_computer,
+            )
+        if len(args) == 1:
+            return (
+                ReflexCallable[[int], args[0]._var_type],
+                functools.partial(type_computer, *args),
+            )
+        return (ReflexCallable[[], args[0]._var_type], None)
+
     return var_operation_return(
     return var_operation_return(
         js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})",
         js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})",
-        var_type=array._var_type,
+        type_computer=type_computer,
     )
     )
 
 
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
-    from .function import FunctionVar
+    pass
 
 
 
 
 @var_operation
 @var_operation
 def map_array_operation(
 def map_array_operation(
-    array: ArrayVar[ARRAY_VAR_TYPE],
-    function: FunctionVar,
+    array: Var[ARRAY_VAR_TYPE],
+    function: Var[ReflexCallable],
 ):
 ):
     """Map a function over an array.
     """Map a function over an array.
 
 
@@ -1634,14 +1673,33 @@ def map_array_operation(
     Returns:
     Returns:
         The mapped array.
         The mapped array.
     """
     """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (
+                ReflexCallable[[List[Any], ReflexCallable], List[Any]],
+                type_computer,
+            )
+        if len(args) == 1:
+            return (
+                ReflexCallable[[ReflexCallable], List[Any]],
+                functools.partial(type_computer, *args),
+            )
+        return (ReflexCallable[[], List[args[0]._var_type]], None)
+
     return var_operation_return(
     return var_operation_return(
-        js_expression=f"{array}.map({function})", var_type=List[Any]
+        js_expression=f"{array}.map({function})",
+        type_computer=nary_type_computer(
+            ReflexCallable[[List[Any], ReflexCallable], List[Any]],
+            ReflexCallable[[ReflexCallable], List[Any]],
+            computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]],
+        ),
     )
     )
 
 
 
 
 @var_operation
 @var_operation
 def array_concat_operation(
 def array_concat_operation(
-    lhs: ArrayVar[ARRAY_VAR_TYPE], rhs: ArrayVar[ARRAY_VAR_TYPE]
+    lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
     """Concatenate two arrays.
     """Concatenate two arrays.
 
 
@@ -1654,7 +1712,11 @@ def array_concat_operation(
     """
     """
     return var_operation_return(
     return var_operation_return(
         js_expression=f"[...{lhs}, ...{rhs}]",
         js_expression=f"[...{lhs}, ...{rhs}]",
-        var_type=Union[lhs._var_type, rhs._var_type],
+        type_computer=nary_type_computer(
+            ReflexCallable[[List[Any], List[Any]], List[Any]],
+            ReflexCallable[[List[Any]], List[Any]],
+            computer=lambda args: unionize(args[0]._var_type, args[1]._var_type),
+        ),
     )
     )
 
 
 
 

+ 2 - 2
tests/units/test_var.py

@@ -963,11 +963,11 @@ def test_function_var():
 
 
 def test_var_operation():
 def test_var_operation():
     @var_operation
     @var_operation
-    def add(a: Union[NumberVar, int], b: Union[NumberVar, int]):
+    def add(a: Var[int], b: Var[int]):
         return var_operation_return(js_expression=f"({a} + {b})", var_type=int)
         return var_operation_return(js_expression=f"({a} + {b})", var_type=int)
 
 
     assert str(add(1, 2)) == "(1 + 2)"
     assert str(add(1, 2)) == "(1 + 2)"
-    assert str(add(a=4, b=-9)) == "(4 + -9)"
+    assert str(add(4, -9)) == "(4 + -9)"
 
 
     five = LiteralNumberVar.create(5)
     five = LiteralNumberVar.create(5)
     seven = add(2, five)
     seven = add(2, five)