فهرست منبع

add typing to function vars (#4372)

* add typing to function vars

* import ParamSpec from typing_extensions

* remove ellipsis as they are not supported in 3.9

* try importing everything from extensions

* special case 3.9

* don't use Any from extensions

* get typevar from extensions
Khaleel Al-Adhami 6 ماه پیش
والد
کامیت
a639f526da
5فایلهای تغییر یافته به همراه309 افزوده شده و 47 حذف شده
  1. 5 4
      reflex/event.py
  2. 2 1
      reflex/utils/telemetry.py
  3. 18 4
      reflex/vars/base.py
  4. 283 37
      reflex/vars/function.py
  5. 1 1
      tests/units/test_var.py

+ 5 - 4
reflex/event.py

@@ -45,6 +45,8 @@ from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.function import (
     ArgsFunctionOperation,
+    ArgsFunctionOperationBuilder,
+    BuilderFunctionVar,
     FunctionArgs,
     FunctionStringVar,
     FunctionVar,
@@ -797,8 +799,7 @@ def scroll_to(elem_id: str, align_to_top: bool | Var[bool] = True) -> EventSpec:
     get_element_by_id = FunctionStringVar.create("document.getElementById")
 
     return run_script(
-        get_element_by_id(elem_id)
-        .call(elem_id)
+        get_element_by_id.call(elem_id)
         .to(ObjectVar)
         .scrollIntoView.to(FunctionVar)
         .call(align_to_top),
@@ -1580,7 +1581,7 @@ class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
         )
 
 
-class EventChainVar(FunctionVar, python_types=EventChain):
+class EventChainVar(BuilderFunctionVar, python_types=EventChain):
     """Base class for event chain vars."""
 
 
@@ -1592,7 +1593,7 @@ class EventChainVar(FunctionVar, python_types=EventChain):
 # Note: LiteralVar is second in the inheritance list allowing it act like a
 # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
 # _cached_var_name property.
-class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
+class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainVar):
     """A literal event chain var."""
 
     _var_value: EventChain = dataclasses.field(default=None)  # type: ignore

+ 2 - 1
reflex/utils/telemetry.py

@@ -51,7 +51,8 @@ def get_python_version() -> str:
     Returns:
         The Python version.
     """
-    return platform.python_version()
+    # Remove the "+" from the version string in case user is using a pre-release version.
+    return platform.python_version().rstrip("+")
 
 
 def get_reflex_version() -> str:

+ 18 - 4
reflex/vars/base.py

@@ -361,21 +361,29 @@ class Var(Generic[VAR_TYPE]):
         return False
 
     def __init_subclass__(
-        cls, python_types: Tuple[GenericType, ...] | GenericType = types.Unset, **kwargs
+        cls,
+        python_types: Tuple[GenericType, ...] | GenericType = types.Unset(),
+        default_type: GenericType = types.Unset(),
+        **kwargs,
     ):
         """Initialize the subclass.
 
         Args:
             python_types: The python types that the var represents.
+            default_type: The default type of the var. Defaults to the first python type.
             **kwargs: Additional keyword arguments.
         """
         super().__init_subclass__(**kwargs)
 
-        if python_types is not types.Unset:
+        if python_types or default_type:
             python_types = (
-                python_types if isinstance(python_types, tuple) else (python_types,)
+                (python_types if isinstance(python_types, tuple) else (python_types,))
+                if python_types
+                else ()
             )
 
+            default_type = default_type or (python_types[0] if python_types else Any)
+
             @dataclasses.dataclass(
                 eq=False,
                 frozen=True,
@@ -388,7 +396,7 @@ class Var(Generic[VAR_TYPE]):
                     default=Var(_js_expr="null", _var_type=None),
                 )
 
-                _default_var_type: ClassVar[GenericType] = python_types[0]
+                _default_var_type: ClassVar[GenericType] = default_type
 
             ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation'
 
@@ -588,6 +596,12 @@ class Var(Generic[VAR_TYPE]):
         output: type[list] | type[tuple] | type[set],
     ) -> ArrayVar: ...
 
+    @overload
+    def to(
+        self,
+        output: type[dict],
+    ) -> ObjectVar[dict]: ...
+
     @overload
     def to(
         self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE]

+ 283 - 37
reflex/vars/function.py

@@ -4,32 +4,177 @@ from __future__ import annotations
 
 import dataclasses
 import sys
-from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
+
+from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
 
 from reflex.utils import format
 from reflex.utils.types import GenericType
 
 from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
 
+P = ParamSpec("P")
+V1 = TypeVar("V1")
+V2 = TypeVar("V2")
+V3 = TypeVar("V3")
+V4 = TypeVar("V4")
+V5 = TypeVar("V5")
+V6 = TypeVar("V6")
+R = TypeVar("R")
+
+
+class ReflexCallable(Protocol[P, R]):
+    """Protocol for a callable."""
+
+    __call__: Callable[P, R]
+
 
-class FunctionVar(Var[Callable], python_types=Callable):
+CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
+OTHER_CALLABLE_TYPE = TypeVar(
+    "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
+)
+
+
+class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
     """Base class for immutable function vars."""
 
-    def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
-        """Call the function with the given arguments.
+    @overload
+    def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
+        arg1: Union[V1, Var[V1]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+        arg5: Union[V5, Var[V5]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+        arg5: Union[V5, Var[V5]],
+        arg6: Union[V6, Var[V6]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(self, *args: Var | Any) -> FunctionVar: ...
+
+    def partial(self, *args: Var | Any) -> FunctionVar:  # type: ignore
+        """Partially apply the function with the given arguments.
 
         Args:
-            *args: The arguments to call the function with.
+            *args: The arguments to partially apply the function with.
 
         Returns:
-            The function call operation.
+            The partially applied function.
         """
+        if not args:
+            return ArgsFunctionOperation.create((), self)
         return ArgsFunctionOperation.create(
             ("...args",),
             VarOperationCall.create(self, *args, Var(_js_expr="...args")),
         )
 
-    def call(self, *args: Var | Any) -> VarOperationCall:
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
+    ) -> VarOperationCall[[V1], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+    ) -> VarOperationCall[[V1, V2], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+    ) -> VarOperationCall[[V1, V2, V3], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+    ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+        arg5: Union[V5, Var[V5]],
+    ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+        arg5: Union[V5, Var[V5]],
+        arg6: Union[V6, Var[V6]],
+    ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
+    ) -> VarOperationCall[P, R]: ...
+
+    @overload
+    def call(self, *args: Var | Any) -> Var: ...
+
+    def call(self, *args: Var | Any) -> Var:  # type: ignore
         """Call the function with the given arguments.
 
         Args:
@@ -38,19 +183,29 @@ class FunctionVar(Var[Callable], python_types=Callable):
         Returns:
             The function call operation.
         """
-        return VarOperationCall.create(self, *args)
+        return VarOperationCall.create(self, *args).guess_type()
+
+    __call__ = call
+
+
+class BuilderFunctionVar(
+    FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
+):
+    """Base class for immutable function vars with the builder pattern."""
+
+    __call__ = FunctionVar.partial
 
 
-class FunctionStringVar(FunctionVar):
+class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
     """Base class for immutable function vars from a string."""
 
     @classmethod
     def create(
         cls,
         func: str,
-        _var_type: Type[Callable] = Callable,
+        _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
         _var_data: VarData | None = None,
-    ) -> FunctionStringVar:
+    ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
         """Create a new function var from a string.
 
         Args:
@@ -60,7 +215,7 @@ class FunctionStringVar(FunctionVar):
         Returns:
             The function var.
         """
-        return cls(
+        return FunctionStringVar(
             _js_expr=func,
             _var_type=_var_type,
             _var_data=_var_data,
@@ -72,10 +227,10 @@ class FunctionStringVar(FunctionVar):
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
-class VarOperationCall(CachedVarOperation, Var):
+class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
     """Base class for immutable vars that are the result of a function call."""
 
-    _func: Optional[FunctionVar] = dataclasses.field(default=None)
+    _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
     _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
 
     @cached_property_no_lock
@@ -103,7 +258,7 @@ class VarOperationCall(CachedVarOperation, Var):
     @classmethod
     def create(
         cls,
-        func: FunctionVar,
+        func: FunctionVar[ReflexCallable[P, R]],
         *args: Var | Any,
         _var_type: GenericType = Any,
         _var_data: VarData | None = None,
@@ -118,9 +273,15 @@ class VarOperationCall(CachedVarOperation, Var):
         Returns:
             The function call var.
         """
+        function_return_type = (
+            func._var_type.__args__[1]
+            if getattr(func._var_type, "__args__", None)
+            else Any
+        )
+        var_type = _var_type if _var_type is not Any else function_return_type
         return cls(
             _js_expr="",
-            _var_type=_var_type,
+            _var_type=var_type,
             _var_data=_var_data,
             _func=func,
             _args=args,
@@ -157,6 +318,33 @@ class FunctionArgs:
     rest: Optional[str] = None
 
 
+def format_args_function_operation(
+    args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
+) -> str:
+    """Format an args function operation.
+
+    Args:
+        args: The function arguments.
+        return_expr: The return expression.
+        explicit_return: Whether to use explicit return syntax.
+
+    Returns:
+        The formatted args function operation.
+    """
+    arg_names_str = ", ".join(
+        [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
+    ) + (f", ...{args.rest}" if args.rest else "")
+
+    return_expr_str = str(LiteralVar.create(return_expr))
+
+    # Wrap return expression in curly braces if explicit return syntax is used.
+    return_expr_str_wrapped = (
+        format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
+    )
+
+    return f"(({arg_names_str}) => {return_expr_str_wrapped})"
+
+
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
@@ -176,24 +364,10 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         Returns:
             The name of the var.
         """
-        arg_names_str = ", ".join(
-            [
-                arg if isinstance(arg, str) else arg.to_javascript()
-                for arg in self._args.args
-            ]
-        ) + (f", ...{self._args.rest}" if self._args.rest else "")
-
-        return_expr_str = str(LiteralVar.create(self._return_expr))
-
-        # Wrap return expression in curly braces if explicit return syntax is used.
-        return_expr_str_wrapped = (
-            format.wrap(return_expr_str, "{", "}")
-            if self._explicit_return
-            else return_expr_str
+        return format_args_function_operation(
+            self._args, self._return_expr, self._explicit_return
         )
 
-        return f"(({arg_names_str}) => {return_expr_str_wrapped})"
-
     @classmethod
     def create(
         cls,
@@ -203,7 +377,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         explicit_return: bool = False,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
-    ) -> ArgsFunctionOperation:
+    ):
         """Create a new function var.
 
         Args:
@@ -226,8 +400,80 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         )
 
 
-JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify")
-ARRAY_ISARRAY = FunctionStringVar.create("Array.isArray")
-PROTOTYPE_TO_STRING = FunctionStringVar.create(
-    "((__to_string) => __to_string.toString())"
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
+class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
+    """Base class for immutable function defined via arguments and return expression with the builder pattern."""
+
+    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
+    _return_expr: Union[Var, Any] = dataclasses.field(default=None)
+    _explicit_return: bool = dataclasses.field(default=False)
+
+    @cached_property_no_lock
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        return format_args_function_operation(
+            self._args, self._return_expr, self._explicit_return
+        )
+
+    @classmethod
+    def create(
+        cls,
+        args_names: Sequence[Union[str, DestructuredArg]],
+        return_expr: Var | Any,
+        rest: str | None = None,
+        explicit_return: bool = False,
+        _var_type: GenericType = Callable,
+        _var_data: VarData | None = None,
+    ):
+        """Create a new function var.
+
+        Args:
+            args_names: The names of the arguments.
+            return_expr: The return expression of the function.
+            rest: The name of the rest argument.
+            explicit_return: Whether to use explicit return syntax.
+            _var_data: Additional hooks and imports associated with the Var.
+
+        Returns:
+            The function var.
+        """
+        return cls(
+            _js_expr="",
+            _var_type=_var_type,
+            _var_data=_var_data,
+            _args=FunctionArgs(args=tuple(args_names), rest=rest),
+            _return_expr=return_expr,
+            _explicit_return=explicit_return,
+        )
+
+
+if python_version := sys.version_info[:2] >= (3, 10):
+    JSON_STRINGIFY = FunctionStringVar.create(
+        "JSON.stringify", _var_type=ReflexCallable[[Any], str]
+    )
+    ARRAY_ISARRAY = FunctionStringVar.create(
+        "Array.isArray", _var_type=ReflexCallable[[Any], bool]
+    )
+    PROTOTYPE_TO_STRING = FunctionStringVar.create(
+        "((__to_string) => __to_string.toString())",
+        _var_type=ReflexCallable[[Any], str],
+    )
+else:
+    JSON_STRINGIFY = FunctionStringVar.create(
+        "JSON.stringify", _var_type=ReflexCallable[Any, str]
+    )
+    ARRAY_ISARRAY = FunctionStringVar.create(
+        "Array.isArray", _var_type=ReflexCallable[Any, bool]
+    )
+    PROTOTYPE_TO_STRING = FunctionStringVar.create(
+        "((__to_string) => __to_string.toString())",
+        _var_type=ReflexCallable[Any, str],
+    )

+ 1 - 1
tests/units/test_var.py

@@ -928,7 +928,7 @@ def test_function_var():
         == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
     )
 
-    increment_func = addition_func(1)
+    increment_func = addition_func.partial(1)
     assert (
         str(increment_func.call(2))
         == "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"