Browse Source

implement type computers

Khaleel Al-Adhami 6 months ago
parent
commit
f4aa1f58c3
6 changed files with 699 additions and 317 deletions
  1. 325 115
      reflex/vars/base.py
  2. 155 77
      reflex/vars/function.py
  3. 62 49
      reflex/vars/number.py
  4. 31 12
      reflex/vars/object.py
  5. 124 62
      reflex/vars/sequence.py
  6. 2 2
      tests/units/test_var.py

+ 325 - 115
reflex/vars/base.py

@@ -14,7 +14,7 @@ import re
 import string
 import sys
 import warnings
-from types import CodeType, FunctionType
+from types import CodeType, EllipsisType, FunctionType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -26,7 +26,6 @@ from typing import (
     Iterable,
     List,
     Literal,
-    NoReturn,
     Optional,
     Set,
     Tuple,
@@ -38,7 +37,14 @@ from typing import (
     overload,
 )
 
-from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
+from typing_extensions import (
+    ParamSpec,
+    Protocol,
+    TypeGuard,
+    deprecated,
+    get_type_hints,
+    override,
+)
 
 from reflex import constants
 from reflex.base import Base
@@ -69,6 +75,7 @@ from reflex.utils.types import (
 if TYPE_CHECKING:
     from reflex.state import BaseState
 
+    from .function import ArgsFunctionOperation, ReflexCallable
     from .number import BooleanVar, NumberVar
     from .object import ObjectVar
     from .sequence import ArrayVar, StringVar
@@ -79,6 +86,36 @@ OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
 
 warnings.filterwarnings("ignore", message="fields may not start with an underscore")
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class ReflexCallable(Protocol[P, R]):
+    """Protocol for a callable."""
+
+    __call__: Callable[P, R]
+
+
+def unwrap_reflex_callalbe(
+    callable_type: GenericType,
+) -> Tuple[Union[EllipsisType, Tuple[GenericType, ...]], GenericType]:
+    """Unwrap the ReflexCallable type.
+
+    Args:
+        callable_type: The ReflexCallable type to unwrap.
+
+    Returns:
+        The unwrapped ReflexCallable type.
+    """
+    if callable_type is ReflexCallable:
+        return Ellipsis, Any
+    if get_origin(callable_type) is not ReflexCallable:
+        return Ellipsis, Any
+    args = get_args(callable_type)
+    if not args or len(args) != 2:
+        return Ellipsis, Any
+    return args
+
 
 @dataclasses.dataclass(
     eq=False,
@@ -409,9 +446,11 @@ class Var(Generic[VAR_TYPE]):
 
         if _var_data or _js_expr != self._js_expr:
             self.__init__(
-                _js_expr=_js_expr,
-                _var_type=self._var_type,
-                _var_data=VarData.merge(self._var_data, _var_data),
+                **{
+                    **dataclasses.asdict(self),
+                    "_js_expr": _js_expr,
+                    "_var_data": VarData.merge(self._var_data, _var_data),
+                }
             )
 
     def __hash__(self) -> int:
@@ -690,6 +729,12 @@ class Var(Generic[VAR_TYPE]):
     @overload
     def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
 
+    @overload
+    def guess_type(self: Var[list] | Var[tuple] | Var[set]) -> ArrayVar: ...
+
+    @overload
+    def guess_type(self: Var[dict]) -> ObjectVar[dict]: ...
+
     @overload
     def guess_type(self) -> Self: ...
 
@@ -1413,71 +1458,94 @@ def get_python_literal(value: Union[LiteralVar, Any]) -> Any | None:
     return value
 
 
+def validate_arg(type_hint: GenericType) -> Callable[[Any], bool]:
+    """Create a validator for an argument.
+
+    Args:
+        type_hint: The type hint of the argument.
+
+    Returns:
+        The validator.
+    """
+
+    def validate(value: Any):
+        return True
+
+    return validate
+
+
 P = ParamSpec("P")
 T = TypeVar("T")
+V1 = TypeVar("V1")
+V2 = TypeVar("V2")
+V3 = TypeVar("V3")
+V4 = TypeVar("V4")
+V5 = TypeVar("V5")
 
 
-# NoReturn is used to match CustomVarOperationReturn with no type hint.
-@overload
-def var_operation(
-    func: Callable[P, CustomVarOperationReturn[NoReturn]],
-) -> Callable[P, Var]: ...
-
+class TypeComputer(Protocol):
+    """A protocol for type computers."""
 
-@overload
-def var_operation(
-    func: Callable[P, CustomVarOperationReturn[bool]],
-) -> Callable[P, BooleanVar]: ...
+    def __call__(self, *args: Var) -> Tuple[GenericType, Union[TypeComputer, None]]:
+        """Compute the type of the operation.
 
+        Args:
+            *args: The arguments to compute the type of.
 
-NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float])
+        Returns:
+            The type of the operation.
+        """
+        ...
 
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[NUMBER_T]],
-) -> Callable[P, NumberVar[NUMBER_T]]: ...
+    func: Callable[[], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[], T]]: ...
 
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[str]],
-) -> Callable[P, StringVar]: ...
-
-
-LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set])
+    func: Callable[[Var[V1]], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[V1], T]]: ...
 
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[LIST_T]],
-) -> Callable[P, ArrayVar[LIST_T]]: ...
+    func: Callable[[Var[V1], Var[V2]], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[V1, V2], T]]: ...
 
 
-OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
+@overload
+def var_operation(
+    func: Callable[[Var[V1], Var[V2], Var[V3]], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3], T]]: ...
 
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]],
-) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
+    func: Callable[[Var[V1], Var[V2], Var[V3], Var[V4]], CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4], T]]: ...
 
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[T]],
-) -> Callable[P, Var[T]]: ...
+    func: Callable[
+        [Var[V1], Var[V2], Var[V3], Var[V4], Var[V5]],
+        CustomVarOperationReturn[T],
+    ],
+) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4, V5], T]]: ...
 
 
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[T]],
-) -> Callable[P, Var[T]]:
+    func: Callable[..., CustomVarOperationReturn[T]],
+) -> ArgsFunctionOperation:
     """Decorator for creating a var operation.
 
     Example:
     ```python
     @var_operation
-    def add(a: NumberVar, b: NumberVar):
+    def add(a: Var[int], b: Var[int]):
         return custom_var_operation(f"{a} + {b}")
     ```
 
@@ -1487,26 +1555,61 @@ def var_operation(
     Returns:
         The decorated function.
     """
+    from .function import ArgsFunctionOperation, ReflexCallable
 
-    @functools.wraps(func)
-    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]:
-        func_args = list(inspect.signature(func).parameters)
-        args_vars = {
-            func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg)
-            for i, arg in enumerate(args)
-        }
-        kwargs_vars = {
-            key: LiteralVar.create(value) if not isinstance(value, Var) else value
-            for key, value in kwargs.items()
-        }
-
-        return CustomVarOperation.create(
-            name=func.__name__,
-            args=tuple(list(args_vars.items()) + list(kwargs_vars.items())),
-            return_var=func(*args_vars.values(), **kwargs_vars),  # type: ignore
-        ).guess_type()
+    func_name = func.__name__
 
-    return wrapper
+    func_arg_spec = inspect.getfullargspec(func)
+
+    if func_arg_spec.kwonlyargs:
+        raise TypeError(f"Function {func_name} cannot have keyword-only arguments.")
+    if func_arg_spec.varargs:
+        raise TypeError(f"Function {func_name} cannot have variable arguments.")
+
+    arg_names = func_arg_spec.args
+
+    type_hints = get_type_hints(func)
+
+    if not all(
+        (get_origin((type_hint := type_hints.get(arg_name, Any))) or type_hint) is Var
+        and len(get_args(type_hint)) <= 1
+        for arg_name in arg_names
+    ):
+        raise TypeError(
+            f"Function {func_name} must have type hints of the form `Var[Type]`."
+        )
+
+    args_with_type_hints = tuple(
+        (arg_name, (args[0] if (args := get_args(type_hints[arg_name])) else Any))
+        for arg_name in arg_names
+    )
+
+    arg_vars = tuple(
+        (
+            Var("_" + arg_name, _var_type=arg_python_type)
+            if not isinstance(arg_python_type, TypeVar)
+            else Var("_" + arg_name)
+        )
+        for arg_name, arg_python_type in args_with_type_hints
+    )
+
+    custom_operation_return = func(*arg_vars)
+
+    args_operation = ArgsFunctionOperation.create(
+        tuple(map(str, arg_vars)),
+        custom_operation_return,
+        validators=tuple(
+            validate_arg(type_hints.get(arg_name, Any)) for arg_name in arg_names
+        ),
+        function_name=func_name,
+        type_computer=custom_operation_return._type_computer,
+        _var_type=ReflexCallable[
+            tuple(arg_python_type for _, arg_python_type in args_with_type_hints),
+            custom_operation_return._var_type,
+        ],
+    )
+
+    return args_operation
 
 
 def figure_out_type(value: Any) -> types.GenericType:
@@ -1621,66 +1724,6 @@ class CachedVarOperation:
         )
 
 
-def and_operation(a: Var | Any, b: Var | Any) -> Var:
-    """Perform a logical AND operation on two variables.
-
-    Args:
-        a: The first variable.
-        b: The second variable.
-
-    Returns:
-        The result of the logical AND operation.
-    """
-    return _and_operation(a, b)  # type: ignore
-
-
-@var_operation
-def _and_operation(a: Var, b: Var):
-    """Perform a logical AND operation on two variables.
-
-    Args:
-        a: The first variable.
-        b: The second variable.
-
-    Returns:
-        The result of the logical AND operation.
-    """
-    return var_operation_return(
-        js_expression=f"({a} && {b})",
-        var_type=unionize(a._var_type, b._var_type),
-    )
-
-
-def or_operation(a: Var | Any, b: Var | Any) -> Var:
-    """Perform a logical OR operation on two variables.
-
-    Args:
-        a: The first variable.
-        b: The second variable.
-
-    Returns:
-        The result of the logical OR operation.
-    """
-    return _or_operation(a, b)  # type: ignore
-
-
-@var_operation
-def _or_operation(a: Var, b: Var):
-    """Perform a logical OR operation on two variables.
-
-    Args:
-        a: The first variable.
-        b: The second variable.
-
-    Returns:
-        The result of the logical OR operation.
-    """
-    return var_operation_return(
-        js_expression=f"({a} || {b})",
-        var_type=unionize(a._var_type, b._var_type),
-    )
-
-
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
@@ -2289,14 +2332,22 @@ def computed_var(
 RETURN = TypeVar("RETURN")
 
 
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
 class CustomVarOperationReturn(Var[RETURN]):
     """Base class for custom var operations."""
 
+    _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
+
     @classmethod
     def create(
         cls,
         js_expression: str,
         _var_type: Type[RETURN] | None = None,
+        _type_computer: Optional[TypeComputer] = None,
         _var_data: VarData | None = None,
     ) -> CustomVarOperationReturn[RETURN]:
         """Create a CustomVarOperation.
@@ -2304,6 +2355,7 @@ class CustomVarOperationReturn(Var[RETURN]):
         Args:
             js_expression: The JavaScript expression to evaluate.
             _var_type: The type of the var.
+            _type_computer: A function to compute the type of the var given the arguments.
             _var_data: Additional hooks and imports associated with the Var.
 
         Returns:
@@ -2312,6 +2364,7 @@ class CustomVarOperationReturn(Var[RETURN]):
         return CustomVarOperationReturn(
             _js_expr=js_expression,
             _var_type=_var_type or Any,
+            _type_computer=_type_computer,
             _var_data=_var_data,
         )
 
@@ -2319,6 +2372,7 @@ class CustomVarOperationReturn(Var[RETURN]):
 def var_operation_return(
     js_expression: str,
     var_type: Type[RETURN] | None = None,
+    type_computer: Optional[TypeComputer] = None,
     var_data: VarData | None = None,
 ) -> CustomVarOperationReturn[RETURN]:
     """Shortcut for creating a CustomVarOperationReturn.
@@ -2326,15 +2380,17 @@ def var_operation_return(
     Args:
         js_expression: The JavaScript expression to evaluate.
         var_type: The type of the var.
+        type_computer: A function to compute the type of the var given the arguments.
         var_data: Additional hooks and imports associated with the Var.
 
     Returns:
         The CustomVarOperationReturn.
     """
     return CustomVarOperationReturn.create(
-        js_expression,
-        var_type,
-        var_data,
+        js_expression=js_expression,
+        _var_type=var_type,
+        _type_computer=type_computer,
+        _var_data=var_data,
     )
 
 
@@ -2942,3 +2998,157 @@ def field(value: T) -> Field[T]:
         The Field.
     """
     return value  # type: ignore
+
+
+def and_operation(a: Var | Any, b: Var | Any) -> Var:
+    """Perform a logical AND operation on two variables.
+
+    Args:
+        a: The first variable.
+        b: The second variable.
+
+    Returns:
+        The result of the logical AND operation.
+    """
+    return _and_operation(a, b)  # type: ignore
+
+
+@var_operation
+def _and_operation(a: Var, b: Var):
+    """Perform a logical AND operation on two variables.
+
+    Args:
+        a: The first variable.
+        b: The second variable.
+
+    Returns:
+        The result of the logical AND operation.
+    """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (ReflexCallable[[Any, Any], Any], type_computer)
+        if len(args) == 1:
+            return (
+                ReflexCallable[[Any], Any],
+                functools.partial(type_computer, args[0]),
+            )
+        return (
+            ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
+            None,
+        )
+
+    return var_operation_return(
+        js_expression=f"({a} && {b})",
+        type_computer=type_computer,
+    )
+
+
+def or_operation(a: Var | Any, b: Var | Any) -> Var:
+    """Perform a logical OR operation on two variables.
+
+    Args:
+        a: The first variable.
+        b: The second variable.
+
+    Returns:
+        The result of the logical OR operation.
+    """
+    return _or_operation(a, b)  # type: ignore
+
+
+@var_operation
+def _or_operation(a: Var, b: Var):
+    """Perform a logical OR operation on two variables.
+
+    Args:
+        a: The first variable.
+        b: The second variable.
+
+    Returns:
+        The result of the logical OR operation.
+    """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (ReflexCallable[[Any, Any], Any], type_computer)
+        if len(args) == 1:
+            return (
+                ReflexCallable[[Any], Any],
+                functools.partial(type_computer, args[0]),
+            )
+        return (
+            ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
+            None,
+        )
+
+    return var_operation_return(
+        js_expression=f"({a} || {b})",
+        type_computer=type_computer,
+    )
+
+
+def passthrough_unary_type_computer(no_args: GenericType) -> TypeComputer:
+    """Create a type computer for unary operations.
+
+    Args:
+        no_args: The type to return when no arguments are provided.
+
+    Returns:
+        The type computer.
+    """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (no_args, type_computer)
+        return (ReflexCallable[[], args[0]._var_type], None)
+
+    return type_computer
+
+
+def unary_type_computer(
+    no_args: GenericType, computer: Callable[[Var], GenericType]
+) -> TypeComputer:
+    """Create a type computer for unary operations.
+
+    Args:
+        no_args: The type to return when no arguments are provided.
+        computer: The function to compute the type.
+
+    Returns:
+        The type computer.
+    """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (no_args, type_computer)
+        return (ReflexCallable[[], computer(args[0])], None)
+
+    return type_computer
+
+
+def nary_type_computer(
+    *types: GenericType, computer: Callable[..., GenericType]
+) -> TypeComputer:
+    """Create a type computer for n-ary operations.
+
+    Args:
+        types: The types to return when no arguments are provided.
+        computer: The function to compute the type.
+
+    Returns:
+        The type computer.
+    """
+
+    def type_computer(*args: Var):
+        if len(args) != len(types):
+            return (
+                ReflexCallable[[], types[len(args)]],
+                functools.partial(type_computer, *args),
+            )
+        return (
+            ReflexCallable[[], computer(args)],
+            None,
+        )
+
+    return type_computer

+ 155 - 77
reflex/vars/function.py

@@ -6,28 +6,31 @@ import dataclasses
 import sys
 from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
 
-from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
+from typing_extensions import Concatenate, Generic, ParamSpec, TypeVar
 
 from reflex.utils import format
 from reflex.utils.exceptions import VarTypeError
 from reflex.utils.types import GenericType
 
-from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
+from .base import (
+    CachedVarOperation,
+    LiteralVar,
+    ReflexCallable,
+    TypeComputer,
+    Var,
+    VarData,
+    cached_property_no_lock,
+    unwrap_reflex_callalbe,
+)
 
 P = ParamSpec("P")
+R = TypeVar("R")
 V1 = TypeVar("V1")
 V2 = TypeVar("V2")
 V3 = TypeVar("V3")
 V4 = TypeVar("V4")
 V5 = TypeVar("V5")
 V6 = TypeVar("V6")
-R = TypeVar("R")
-
-
-class ReflexCallable(Protocol[P, R]):
-    """Protocol for a callable."""
-
-    __call__: Callable[P, R]
 
 
 CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
@@ -112,20 +115,37 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
         """
         if not args:
             return self
+
+        args = tuple(map(LiteralVar.create, args))
+
         remaining_validators = self._pre_check(*args)
+
+        partial_types, type_computer = self._partial_type(*args)
+
         if self.__call__ is self.partial:
             # if the default behavior is partial, we should return a new partial function
             return ArgsFunctionOperationBuilder.create(
                 (),
-                VarOperationCall.create(self, *args, Var(_js_expr="...args")),
+                VarOperationCall.create(
+                    self,
+                    *args,
+                    Var(_js_expr="...args"),
+                    _var_type=self._return_type(*args),
+                ),
                 rest="args",
                 validators=remaining_validators,
+                type_computer=type_computer,
+                _var_type=partial_types,
             )
         return ArgsFunctionOperation.create(
             (),
-            VarOperationCall.create(self, *args, Var(_js_expr="...args")),
+            VarOperationCall.create(
+                self, *args, Var(_js_expr="...args"), _var_type=self._return_type(*args)
+            ),
             rest="args",
             validators=remaining_validators,
+            type_computer=type_computer,
+            _var_type=partial_types,
         )
 
     @overload
@@ -194,9 +214,56 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
 
         Returns:
             The function call operation.
+
+        Raises:
+            VarTypeError: If the number of arguments is invalid
         """
+        arg_len = self._arg_len()
+        if arg_len is not None and len(args) != arg_len:
+            raise VarTypeError(f"Invalid number of arguments provided to {str(self)}")
+        args = tuple(map(LiteralVar.create, args))
         self._pre_check(*args)
-        return VarOperationCall.create(self, *args).guess_type()
+        return_type = self._return_type(*args)
+        return VarOperationCall.create(self, *args, _var_type=return_type).guess_type()
+
+    def _partial_type(
+        self, *args: Var | Any
+    ) -> Tuple[GenericType, Optional[TypeComputer]]:
+        """Override the type of the function call with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            The overridden type of the function call.
+        """
+        args_types, return_type = unwrap_reflex_callalbe(self._var_type)
+        if isinstance(args_types, tuple):
+            return ReflexCallable[[*args_types[len(args) :]], return_type], None
+        return ReflexCallable[..., return_type], None
+
+    def _arg_len(self) -> int | None:
+        """Get the number of arguments the function takes.
+
+        Returns:
+            The number of arguments the function takes.
+        """
+        args_types, _ = unwrap_reflex_callalbe(self._var_type)
+        if isinstance(args_types, tuple):
+            return len(args_types)
+        return None
+
+    def _return_type(self, *args: Var | Any) -> GenericType:
+        """Override the type of the function call with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            The overridden type of the function call.
+        """
+        partial_types, _ = self._partial_type(*args)
+        return unwrap_reflex_callalbe(partial_types)[1]
 
     def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
         """Check if the function can be called with the given arguments.
@@ -343,11 +410,12 @@ class FunctionArgs:
 
 
 def format_args_function_operation(
-    args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
+    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
 ) -> str:
     """Format an args function operation.
 
     Args:
+        self: The function operation.
         args: The function arguments.
         return_expr: The return expression.
         explicit_return: Whether to use explicit return syntax.
@@ -356,26 +424,76 @@ def format_args_function_operation(
         The formatted args function operation.
     """
     arg_names_str = ", ".join(
-        [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
-        + ([f"...{args.rest}"] if args.rest else [])
+        [
+            arg if isinstance(arg, str) else arg.to_javascript()
+            for arg in self._args.args
+        ]
+        + ([f"...{self._args.rest}"] if self._args.rest else [])
     )
 
-    return_expr_str = str(LiteralVar.create(return_expr))
+    return_expr_str = str(LiteralVar.create(self._return_expr))
 
     # Wrap return expression in curly braces if explicit return syntax is used.
     return_expr_str_wrapped = (
-        format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
+        format.wrap(return_expr_str, "{", "}")
+        if self._explicit_return
+        else return_expr_str
     )
 
     return f"(({arg_names_str}) => {return_expr_str_wrapped})"
 
 
+def pre_check_args(
+    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, *args: Var | Any
+) -> Tuple[Callable[[Any], bool], ...]:
+    """Check if the function can be called with the given arguments.
+
+    Args:
+        self: The function operation.
+        *args: The arguments to call the function with.
+
+    Returns:
+        True if the function can be called with the given arguments.
+    """
+    for i, (validator, arg) in enumerate(zip(self._validators, args)):
+        if not validator(arg):
+            arg_name = self._args.args[i] if i < len(self._args.args) else None
+            if arg_name is not None:
+                raise VarTypeError(
+                    f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
+                )
+            raise VarTypeError(
+                f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
+            )
+    return self._validators[len(args) :]
+
+
+def figure_partial_type(
+    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
+    *args: Var | Any,
+) -> Tuple[GenericType, Optional[TypeComputer]]:
+    """Figure out the return type of the function.
+
+    Args:
+        self: The function operation.
+        *args: The arguments to call the function with.
+
+    Returns:
+        The return type of the function.
+    """
+    return (
+        self._type_computer(*args)
+        if self._type_computer is not None
+        else FunctionVar._partial_type(self, *args)
+    )
+
+
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
-class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
+class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]):
     """Base class for immutable function defined via arguments and return expression."""
 
     _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
@@ -384,39 +502,14 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
     )
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
     _function_name: str = dataclasses.field(default="")
+    _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
     _explicit_return: bool = dataclasses.field(default=False)
 
-    @cached_property_no_lock
-    def _cached_var_name(self) -> str:
-        """The name of the var.
+    _cached_var_name = cached_property_no_lock(format_args_function_operation)
 
-        Returns:
-            The name of the var.
-        """
-        return format_args_function_operation(
-            self._args, self._return_expr, self._explicit_return
-        )
-
-    def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
-        """Check if the function can be called with the given arguments.
-
-        Args:
-            *args: The arguments to call the function with.
+    _pre_check = pre_check_args
 
-        Returns:
-            True if the function can be called with the given arguments.
-        """
-        for i, (validator, arg) in enumerate(zip(self._validators, args)):
-            if not validator(arg):
-                arg_name = self._args.args[i] if i < len(self._args.args) else None
-                if arg_name is not None:
-                    raise VarTypeError(
-                        f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
-                    )
-                raise VarTypeError(
-                    f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
-                )
-        return self._validators[len(args) :]
+    _partial_type = figure_partial_type
 
     @classmethod
     def create(
@@ -427,6 +520,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         validators: Sequence[Callable[[Any], bool]] = (),
         function_name: str = "",
         explicit_return: bool = False,
+        type_computer: Optional[TypeComputer] = None,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
     ):
@@ -439,6 +533,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             validators: The validators for the arguments.
             function_name: The name of the function.
             explicit_return: Whether to use explicit return syntax.
+            type_computer: A function to compute the return type.
+            _var_type: The type of the var.
             _var_data: Additional hooks and imports associated with the Var.
 
         Returns:
@@ -453,6 +549,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
             _validators=tuple(validators),
             _return_expr=return_expr,
             _explicit_return=explicit_return,
+            _type_computer=type_computer,
         )
 
 
@@ -461,7 +558,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
-class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
+class ArgsFunctionOperationBuilder(
+    CachedVarOperation, BuilderFunctionVar[CALLABLE_TYPE]
+):
     """Base class for immutable function defined via arguments and return expression with the builder pattern."""
 
     _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
@@ -470,39 +569,14 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
     )
     _return_expr: Union[Var, Any] = dataclasses.field(default=None)
     _function_name: str = dataclasses.field(default="")
+    _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
     _explicit_return: bool = dataclasses.field(default=False)
 
-    @cached_property_no_lock
-    def _cached_var_name(self) -> str:
-        """The name of the var.
+    _cached_var_name = cached_property_no_lock(format_args_function_operation)
 
-        Returns:
-            The name of the var.
-        """
-        return format_args_function_operation(
-            self._args, self._return_expr, self._explicit_return
-        )
+    _pre_check = pre_check_args
 
-    def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
-        """Check if the function can be called with the given arguments.
-
-        Args:
-            *args: The arguments to call the function with.
-
-        Returns:
-            True if the function can be called with the given arguments.
-        """
-        for i, (validator, arg) in enumerate(zip(self._validators, args)):
-            if not validator(arg):
-                arg_name = self._args.args[i] if i < len(self._args.args) else None
-                if arg_name is not None:
-                    raise VarTypeError(
-                        f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
-                    )
-                raise VarTypeError(
-                    f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
-                )
-        return self._validators[len(args) :]
+    _partial_type = figure_partial_type
 
     @classmethod
     def create(
@@ -513,6 +587,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
         validators: Sequence[Callable[[Any], bool]] = (),
         function_name: str = "",
         explicit_return: bool = False,
+        type_computer: Optional[TypeComputer] = None,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
     ):
@@ -525,6 +600,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
             validators: The validators for the arguments.
             function_name: The name of the function.
             explicit_return: Whether to use explicit return syntax.
+            type_computer: A function to compute the return type.
+            _var_type: The type of the var.
             _var_data: Additional hooks and imports associated with the Var.
 
         Returns:
@@ -539,6 +616,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
             _validators=tuple(validators),
             _return_expr=return_expr,
             _explicit_return=explicit_return,
+            _type_computer=type_computer,
         )
 
 

+ 62 - 49
reflex/vars/number.py

@@ -3,19 +3,11 @@
 from __future__ import annotations
 
 import dataclasses
+import functools
 import json
 import math
 import sys
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    NoReturn,
-    Type,
-    TypeVar,
-    Union,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, Callable, NoReturn, TypeVar, Union, overload
 
 from reflex.constants.base import Dirs
 from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
@@ -25,8 +17,11 @@ from reflex.utils.types import is_optional
 from .base import (
     CustomVarOperationReturn,
     LiteralVar,
+    ReflexCallable,
     Var,
     VarData,
+    nary_type_computer,
+    passthrough_unary_type_computer,
     unionize,
     var_operation,
     var_operation_return,
@@ -544,8 +539,8 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
 
 
 def binary_number_operation(
-    func: Callable[[NumberVar, NumberVar], str],
-) -> Callable[[number_types, number_types], NumberVar]:
+    func: Callable[[Var[int | float], Var[int | float]], str],
+):
     """Decorator to create a binary number operation.
 
     Args:
@@ -555,30 +550,37 @@ def binary_number_operation(
         The binary number operation.
     """
 
-    @var_operation
-    def operation(lhs: NumberVar, rhs: NumberVar):
+    def operation(
+        lhs: Var[int | float], rhs: Var[int | float]
+    ) -> CustomVarOperationReturn[int | float]:
+        def type_computer(*args: Var):
+            if not args:
+                return (
+                    ReflexCallable[[int | float, int | float], int | float],
+                    type_computer,
+                )
+            if len(args) == 1:
+                return (
+                    ReflexCallable[[int | float], int | float],
+                    functools.partial(type_computer, args[0]),
+                )
+            return (
+                ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
+                None,
+            )
+
         return var_operation_return(
             js_expression=func(lhs, rhs),
-            var_type=unionize(lhs._var_type, rhs._var_type),
+            type_computer=type_computer,
         )
 
-    def wrapper(lhs: number_types, rhs: number_types) -> NumberVar:
-        """Create the binary number operation.
-
-        Args:
-            lhs: The first number.
-            rhs: The second number.
-
-        Returns:
-            The binary number operation.
-        """
-        return operation(lhs, rhs)  # type: ignore
+    operation.__name__ = func.__name__
 
-    return wrapper
+    return var_operation(operation)
 
 
 @binary_number_operation
-def number_add_operation(lhs: NumberVar, rhs: NumberVar):
+def number_add_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Add two numbers.
 
     Args:
@@ -592,7 +594,7 @@ def number_add_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 @binary_number_operation
-def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
+def number_subtract_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Subtract two numbers.
 
     Args:
@@ -605,8 +607,15 @@ def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
     return f"({lhs} - {rhs})"
 
 
+unary_operation_type_computer = passthrough_unary_type_computer(
+    ReflexCallable[[int | float], int | float]
+)
+
+
 @var_operation
-def number_abs_operation(value: NumberVar):
+def number_abs_operation(
+    value: Var[int | float],
+) -> CustomVarOperationReturn[int | float]:
     """Get the absolute value of the number.
 
     Args:
@@ -616,12 +625,12 @@ def number_abs_operation(value: NumberVar):
         The number absolute operation.
     """
     return var_operation_return(
-        js_expression=f"Math.abs({value})", var_type=value._var_type
+        js_expression=f"Math.abs({value})", type_computer=unary_operation_type_computer
     )
 
 
 @binary_number_operation
-def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
+def number_multiply_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Multiply two numbers.
 
     Args:
@@ -636,7 +645,7 @@ def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
 
 @var_operation
 def number_negate_operation(
-    value: NumberVar[NUMBER_T],
+    value: Var[NUMBER_T],
 ) -> CustomVarOperationReturn[NUMBER_T]:
     """Negate the number.
 
@@ -646,11 +655,13 @@ def number_negate_operation(
     Returns:
         The number negation operation.
     """
-    return var_operation_return(js_expression=f"-({value})", var_type=value._var_type)
+    return var_operation_return(
+        js_expression=f"-({value})", type_computer=unary_operation_type_computer
+    )
 
 
 @binary_number_operation
-def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
+def number_true_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Divide two numbers.
 
     Args:
@@ -664,7 +675,7 @@ def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 @binary_number_operation
-def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
+def number_floor_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Floor divide two numbers.
 
     Args:
@@ -678,7 +689,7 @@ def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 @binary_number_operation
-def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
+def number_modulo_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Modulo two numbers.
 
     Args:
@@ -692,7 +703,7 @@ def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 @binary_number_operation
-def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
+def number_exponent_operation(lhs: Var[int | float], rhs: Var[int | float]):
     """Exponentiate two numbers.
 
     Args:
@@ -706,7 +717,7 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 @var_operation
-def number_round_operation(value: NumberVar):
+def number_round_operation(value: Var[int | float]):
     """Round the number.
 
     Args:
@@ -719,7 +730,7 @@ def number_round_operation(value: NumberVar):
 
 
 @var_operation
-def number_ceil_operation(value: NumberVar):
+def number_ceil_operation(value: Var[int | float]):
     """Ceil the number.
 
     Args:
@@ -732,7 +743,7 @@ def number_ceil_operation(value: NumberVar):
 
 
 @var_operation
-def number_floor_operation(value: NumberVar):
+def number_floor_operation(value: Var[int | float]):
     """Floor the number.
 
     Args:
@@ -745,7 +756,7 @@ def number_floor_operation(value: NumberVar):
 
 
 @var_operation
-def number_trunc_operation(value: NumberVar):
+def number_trunc_operation(value: Var[int | float]):
     """Trunc the number.
 
     Args:
@@ -838,7 +849,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
 
 
 @var_operation
-def boolean_to_number_operation(value: BooleanVar):
+def boolean_to_number_operation(value: Var[bool]):
     """Convert the boolean to a number.
 
     Args:
@@ -969,7 +980,7 @@ def not_equal_operation(lhs: Var, rhs: Var):
 
 
 @var_operation
-def boolean_not_operation(value: BooleanVar):
+def boolean_not_operation(value: Var[bool]):
     """Boolean NOT the boolean.
 
     Args:
@@ -1117,7 +1128,7 @@ U = TypeVar("U")
 
 @var_operation
 def ternary_operation(
-    condition: BooleanVar, if_true: Var[T], if_false: Var[U]
+    condition: Var[bool], if_true: Var[T], if_false: Var[U]
 ) -> CustomVarOperationReturn[Union[T, U]]:
     """Create a ternary operation.
 
@@ -1129,12 +1140,14 @@ def ternary_operation(
     Returns:
         The ternary operation.
     """
-    type_value: Union[Type[T], Type[U]] = unionize(
-        if_true._var_type, if_false._var_type
-    )
     value: CustomVarOperationReturn[Union[T, U]] = var_operation_return(
         js_expression=f"({condition} ? {if_true} : {if_false})",
-        var_type=type_value,
+        type_computer=nary_type_computer(
+            ReflexCallable[[bool, Any, Any], Any],
+            ReflexCallable[[Any, Any], Any],
+            ReflexCallable[[Any], Any],
+            computer=lambda args: unionize(args[1]._var_type, args[2]._var_type),
+        ),
     )
     return value
 

+ 31 - 12
reflex/vars/object.py

@@ -21,15 +21,23 @@ from typing import (
 
 from reflex.utils import types
 from reflex.utils.exceptions import VarAttributeError
-from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
+from reflex.utils.types import (
+    GenericType,
+    get_attribute_access_type,
+    get_origin,
+    unionize,
+)
 
 from .base import (
     CachedVarOperation,
     LiteralVar,
+    ReflexCallable,
     Var,
     VarData,
     cached_property_no_lock,
     figure_out_type,
+    nary_type_computer,
+    unary_type_computer,
     var_operation,
     var_operation_return,
 )
@@ -406,7 +414,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
 
 
 @var_operation
-def object_keys_operation(value: ObjectVar):
+def object_keys_operation(value: Var):
     """Get the keys of an object.
 
     Args:
@@ -422,7 +430,7 @@ def object_keys_operation(value: ObjectVar):
 
 
 @var_operation
-def object_values_operation(value: ObjectVar):
+def object_values_operation(value: Var):
     """Get the values of an object.
 
     Args:
@@ -433,12 +441,15 @@ def object_values_operation(value: ObjectVar):
     """
     return var_operation_return(
         js_expression=f"Object.values({value})",
-        var_type=List[value._value_type()],
+        type_computer=unary_type_computer(
+            ReflexCallable[[Any], List[Any]],
+            lambda x: List[x.to(ObjectVar)._value_type()],
+        ),
     )
 
 
 @var_operation
-def object_entries_operation(value: ObjectVar):
+def object_entries_operation(value: Var):
     """Get the entries of an object.
 
     Args:
@@ -447,14 +458,18 @@ def object_entries_operation(value: ObjectVar):
     Returns:
         The entries of the object.
     """
+    value = value.to(ObjectVar)
     return var_operation_return(
         js_expression=f"Object.entries({value})",
-        var_type=List[Tuple[str, value._value_type()]],
+        type_computer=unary_type_computer(
+            ReflexCallable[[Any], List[Tuple[str, Any]]],
+            lambda x: List[Tuple[str, x.to(ObjectVar)._value_type()]],
+        ),
     )
 
 
 @var_operation
-def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
+def object_merge_operation(lhs: Var, rhs: Var):
     """Merge two objects.
 
     Args:
@@ -466,10 +481,14 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
     """
     return var_operation_return(
         js_expression=f"({{...{lhs}, ...{rhs}}})",
-        var_type=Dict[
-            Union[lhs._key_type(), rhs._key_type()],
-            Union[lhs._value_type(), rhs._value_type()],
-        ],
+        type_computer=nary_type_computer(
+            ReflexCallable[[Any, Any], Dict[Any, Any]],
+            ReflexCallable[[Any], Dict[Any, Any]],
+            computer=lambda args: Dict[
+                unionize(*[arg.to(ObjectVar)._key_type() for arg in args]),
+                unionize(*[arg.to(ObjectVar)._value_type() for arg in args]),
+            ],
+        ),
     )
 
 
@@ -526,7 +545,7 @@ class ObjectItemOperation(CachedVarOperation, Var):
 
 
 @var_operation
-def object_has_own_property_operation(object: ObjectVar, key: Var):
+def object_has_own_property_operation(object: Var, key: Var):
     """Check if an object has a key.
 
     Args:

+ 124 - 62
reflex/vars/sequence.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import dataclasses
+import functools
 import inspect
 import json
 import re
@@ -34,6 +35,7 @@ from .base import (
     CachedVarOperation,
     CustomVarOperationReturn,
     LiteralVar,
+    ReflexCallable,
     Var,
     VarData,
     _global_vars,
@@ -41,7 +43,10 @@ from .base import (
     figure_out_type,
     get_python_literal,
     get_unique_variable_name,
+    nary_type_computer,
+    passthrough_unary_type_computer,
     unionize,
+    unwrap_reflex_callalbe,
     var_operation,
     var_operation_return,
 )
@@ -353,7 +358,7 @@ class StringVar(Var[STRING_TYPE], python_types=str):
 
 
 @var_operation
-def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
+def string_lt_operation(lhs: Var[str], rhs: Var[str]):
     """Check if a string is less than another string.
 
     Args:
@@ -367,7 +372,7 @@ def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
 
 
 @var_operation
-def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
+def string_gt_operation(lhs: Var[str], rhs: Var[str]):
     """Check if a string is greater than another string.
 
     Args:
@@ -381,7 +386,7 @@ def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
 
 
 @var_operation
-def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
+def string_le_operation(lhs: Var[str], rhs: Var[str]):
     """Check if a string is less than or equal to another string.
 
     Args:
@@ -395,7 +400,7 @@ def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
 
 
 @var_operation
-def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
+def string_ge_operation(lhs: Var[str], rhs: Var[str]):
     """Check if a string is greater than or equal to another string.
 
     Args:
@@ -409,7 +414,7 @@ def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
 
 
 @var_operation
-def string_lower_operation(string: StringVar[Any]):
+def string_lower_operation(string: Var[str]):
     """Convert a string to lowercase.
 
     Args:
@@ -422,7 +427,7 @@ def string_lower_operation(string: StringVar[Any]):
 
 
 @var_operation
-def string_upper_operation(string: StringVar[Any]):
+def string_upper_operation(string: Var[str]):
     """Convert a string to uppercase.
 
     Args:
@@ -435,7 +440,7 @@ def string_upper_operation(string: StringVar[Any]):
 
 
 @var_operation
-def string_strip_operation(string: StringVar[Any]):
+def string_strip_operation(string: Var[str]):
     """Strip a string.
 
     Args:
@@ -449,7 +454,7 @@ def string_strip_operation(string: StringVar[Any]):
 
 @var_operation
 def string_contains_field_operation(
-    haystack: StringVar[Any], needle: StringVar[Any] | str, field: StringVar[Any] | str
+    haystack: Var[str], needle: Var[str], field: Var[str]
 ):
     """Check if a string contains another string.
 
@@ -468,7 +473,7 @@ def string_contains_field_operation(
 
 
 @var_operation
-def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | str):
+def string_contains_operation(haystack: Var[str], needle: Var[str]):
     """Check if a string contains another string.
 
     Args:
@@ -484,9 +489,7 @@ def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] |
 
 
 @var_operation
-def string_starts_with_operation(
-    full_string: StringVar[Any], prefix: StringVar[Any] | str
-):
+def string_starts_with_operation(full_string: Var[str], prefix: Var[str]):
     """Check if a string starts with a prefix.
 
     Args:
@@ -502,7 +505,7 @@ def string_starts_with_operation(
 
 
 @var_operation
-def string_item_operation(string: StringVar[Any], index: NumberVar | int):
+def string_item_operation(string: Var[str], index: Var[int]):
     """Get an item from a string.
 
     Args:
@@ -515,23 +518,9 @@ def string_item_operation(string: StringVar[Any], index: NumberVar | int):
     return var_operation_return(js_expression=f"{string}.at({index})", var_type=str)
 
 
-@var_operation
-def array_join_operation(array: ArrayVar, sep: StringVar[Any] | str = ""):
-    """Join the elements of an array.
-
-    Args:
-        array: The array.
-        sep: The separator.
-
-    Returns:
-        The joined elements.
-    """
-    return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str)
-
-
 @var_operation
 def string_replace_operation(
-    string: StringVar, search_value: StringVar | str, new_value: StringVar | str
+    string: Var[str], search_value: Var[str], new_value: Var[str]
 ):
     """Replace a string with a value.
 
@@ -1046,7 +1035,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
         Returns:
             The array pluck operation.
         """
-        return array_pluck_operation(self, field)
+        return array_pluck_operation(self, field).guess_type()
 
     @overload
     def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: ...
@@ -1300,7 +1289,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
 
 
 @var_operation
-def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""):
+def string_split_operation(string: Var[str], sep: Var[str]):
     """Split a string.
 
     Args:
@@ -1394,9 +1383,9 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar):
 
 @var_operation
 def array_pluck_operation(
-    array: ArrayVar[ARRAY_VAR_TYPE],
-    field: StringVar | str,
-) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
+    array: Var[ARRAY_VAR_TYPE],
+    field: Var[str],
+) -> CustomVarOperationReturn[List]:
     """Pluck a field from an array of objects.
 
     Args:
@@ -1408,13 +1397,27 @@ def array_pluck_operation(
     """
     return var_operation_return(
         js_expression=f"{array}.map(e=>e?.[{field}])",
-        var_type=array._var_type,
+        var_type=List[Any],
     )
 
 
+@var_operation
+def array_join_operation(array: Var[ARRAY_VAR_TYPE], sep: Var[str]):
+    """Join the elements of an array.
+
+    Args:
+        array: The array.
+        sep: The separator.
+
+    Returns:
+        The joined elements.
+    """
+    return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str)
+
+
 @var_operation
 def array_reverse_operation(
-    array: ArrayVar[ARRAY_VAR_TYPE],
+    array: Var[ARRAY_VAR_TYPE],
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
     """Reverse an array.
 
@@ -1426,12 +1429,12 @@ def array_reverse_operation(
     """
     return var_operation_return(
         js_expression=f"{array}.slice().reverse()",
-        var_type=array._var_type,
+        type_computer=passthrough_unary_type_computer(ReflexCallable[[List], List]),
     )
 
 
 @var_operation
-def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
+def array_lt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
     """Check if an array is less than another array.
 
     Args:
@@ -1445,7 +1448,7 @@ def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
 
 
 @var_operation
-def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
+def array_gt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
     """Check if an array is greater than another array.
 
     Args:
@@ -1459,7 +1462,7 @@ def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
 
 
 @var_operation
-def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
+def array_le_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
     """Check if an array is less than or equal to another array.
 
     Args:
@@ -1473,7 +1476,7 @@ def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
 
 
 @var_operation
-def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
+def array_ge_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
     """Check if an array is greater than or equal to another array.
 
     Args:
@@ -1487,7 +1490,7 @@ def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
 
 
 @var_operation
-def array_length_operation(array: ArrayVar):
+def array_length_operation(array: Var[ARRAY_VAR_TYPE]):
     """Get the length of an array.
 
     Args:
@@ -1517,7 +1520,7 @@ def is_tuple_type(t: GenericType) -> bool:
 
 
 @var_operation
-def array_item_operation(array: ArrayVar, index: NumberVar | int):
+def array_item_operation(array: Var[ARRAY_VAR_TYPE], index: Var[int]):
     """Get an item from an array.
 
     Args:
@@ -1527,23 +1530,45 @@ def array_item_operation(array: ArrayVar, index: NumberVar | int):
     Returns:
         The item from the array.
     """
-    args = typing.get_args(array._var_type)
-    if args and isinstance(index, LiteralNumberVar) and is_tuple_type(array._var_type):
-        index_value = int(index._var_value)
-        element_type = args[index_value % len(args)]
-    else:
-        element_type = unionize(*args)
+
+    def type_computer(*args):
+        if len(args) == 0:
+            return (
+                ReflexCallable[[List[Any], int], Any],
+                functools.partial(type_computer, *args),
+            )
+
+        array = args[0]
+        array_args = typing.get_args(array._var_type)
+
+        if len(args) == 1:
+            return (
+                ReflexCallable[[int], unionize(*array_args)],
+                functools.partial(type_computer, *args),
+            )
+
+        index = args[1]
+
+        if (
+            array_args
+            and isinstance(index, LiteralNumberVar)
+            and is_tuple_type(array._var_type)
+        ):
+            index_value = int(index._var_value)
+            element_type = array_args[index_value % len(array_args)]
+        else:
+            element_type = unionize(*array_args)
+
+        return (ReflexCallable[[], element_type], None)
 
     return var_operation_return(
         js_expression=f"{str(array)}.at({str(index)})",
-        var_type=element_type,
+        type_computer=type_computer,
     )
 
 
 @var_operation
-def array_range_operation(
-    start: NumberVar | int, stop: NumberVar | int, step: NumberVar | int
-):
+def array_range_operation(start: Var[int], stop: Var[int], step: Var[int]):
     """Create a range of numbers.
 
     Args:
@@ -1562,7 +1587,7 @@ def array_range_operation(
 
 @var_operation
 def array_contains_field_operation(
-    haystack: ArrayVar, needle: Any | Var, field: StringVar | str
+    haystack: Var[ARRAY_VAR_TYPE], needle: Var, field: Var[str]
 ):
     """Check if an array contains an element.
 
@@ -1581,7 +1606,7 @@ def array_contains_field_operation(
 
 
 @var_operation
-def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
+def array_contains_operation(haystack: Var[ARRAY_VAR_TYPE], needle: Var):
     """Check if an array contains an element.
 
     Args:
@@ -1599,7 +1624,7 @@ def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
 
 @var_operation
 def repeat_array_operation(
-    array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int
+    array: Var[ARRAY_VAR_TYPE], count: Var[int]
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
     """Repeat an array a number of times.
 
@@ -1610,20 +1635,34 @@ def repeat_array_operation(
     Returns:
         The repeated array.
     """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (
+                ReflexCallable[[List[Any], int], List[Any]],
+                type_computer,
+            )
+        if len(args) == 1:
+            return (
+                ReflexCallable[[int], args[0]._var_type],
+                functools.partial(type_computer, *args),
+            )
+        return (ReflexCallable[[], args[0]._var_type], None)
+
     return var_operation_return(
         js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})",
-        var_type=array._var_type,
+        type_computer=type_computer,
     )
 
 
 if TYPE_CHECKING:
-    from .function import FunctionVar
+    pass
 
 
 @var_operation
 def map_array_operation(
-    array: ArrayVar[ARRAY_VAR_TYPE],
-    function: FunctionVar,
+    array: Var[ARRAY_VAR_TYPE],
+    function: Var[ReflexCallable],
 ):
     """Map a function over an array.
 
@@ -1634,14 +1673,33 @@ def map_array_operation(
     Returns:
         The mapped array.
     """
+
+    def type_computer(*args: Var):
+        if not args:
+            return (
+                ReflexCallable[[List[Any], ReflexCallable], List[Any]],
+                type_computer,
+            )
+        if len(args) == 1:
+            return (
+                ReflexCallable[[ReflexCallable], List[Any]],
+                functools.partial(type_computer, *args),
+            )
+        return (ReflexCallable[[], List[args[0]._var_type]], None)
+
     return var_operation_return(
-        js_expression=f"{array}.map({function})", var_type=List[Any]
+        js_expression=f"{array}.map({function})",
+        type_computer=nary_type_computer(
+            ReflexCallable[[List[Any], ReflexCallable], List[Any]],
+            ReflexCallable[[ReflexCallable], List[Any]],
+            computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]],
+        ),
     )
 
 
 @var_operation
 def array_concat_operation(
-    lhs: ArrayVar[ARRAY_VAR_TYPE], rhs: ArrayVar[ARRAY_VAR_TYPE]
+    lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]
 ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
     """Concatenate two arrays.
 
@@ -1654,7 +1712,11 @@ def array_concat_operation(
     """
     return var_operation_return(
         js_expression=f"[...{lhs}, ...{rhs}]",
-        var_type=Union[lhs._var_type, rhs._var_type],
+        type_computer=nary_type_computer(
+            ReflexCallable[[List[Any], List[Any]], List[Any]],
+            ReflexCallable[[List[Any]], List[Any]],
+            computer=lambda args: unionize(args[0]._var_type, args[1]._var_type),
+        ),
     )
 
 

+ 2 - 2
tests/units/test_var.py

@@ -963,11 +963,11 @@ def test_function_var():
 
 def test_var_operation():
     @var_operation
-    def add(a: Union[NumberVar, int], b: Union[NumberVar, int]):
+    def add(a: Var[int], b: Var[int]):
         return var_operation_return(js_expression=f"({a} + {b})", var_type=int)
 
     assert str(add(1, 2)) == "(1 + 2)"
-    assert str(add(a=4, b=-9)) == "(4 + -9)"
+    assert str(add(4, -9)) == "(4 + -9)"
 
     five = LiteralNumberVar.create(5)
     seven = add(2, five)