Ver código fonte

add typing to function vars (#4372)

* add typing to function vars

* import ParamSpec from typing_extensions

* remove ellipsis as they are not supported in 3.9

* try importing everything from extensions

* special case 3.9

* don't use Any from extensions

* get typevar from extensions
Khaleel Al-Adhami 6 meses atrás
pai
commit
a639f526da
5 arquivos alterados com 309 adições e 47 exclusões
  1. 5 4
      reflex/event.py
  2. 2 1
      reflex/utils/telemetry.py
  3. 18 4
      reflex/vars/base.py
  4. 283 37
      reflex/vars/function.py
  5. 1 1
      tests/units/test_var.py

+ 5 - 4
reflex/event.py

@@ -45,6 +45,8 @@ from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.function import (
 from reflex.vars.function import (
     ArgsFunctionOperation,
     ArgsFunctionOperation,
+    ArgsFunctionOperationBuilder,
+    BuilderFunctionVar,
     FunctionArgs,
     FunctionArgs,
     FunctionStringVar,
     FunctionStringVar,
     FunctionVar,
     FunctionVar,
@@ -797,8 +799,7 @@ def scroll_to(elem_id: str, align_to_top: bool | Var[bool] = True) -> EventSpec:
     get_element_by_id = FunctionStringVar.create("document.getElementById")
     get_element_by_id = FunctionStringVar.create("document.getElementById")
 
 
     return run_script(
     return run_script(
-        get_element_by_id(elem_id)
-        .call(elem_id)
+        get_element_by_id.call(elem_id)
         .to(ObjectVar)
         .to(ObjectVar)
         .scrollIntoView.to(FunctionVar)
         .scrollIntoView.to(FunctionVar)
         .call(align_to_top),
         .call(align_to_top),
@@ -1580,7 +1581,7 @@ class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
         )
         )
 
 
 
 
-class EventChainVar(FunctionVar, python_types=EventChain):
+class EventChainVar(BuilderFunctionVar, python_types=EventChain):
     """Base class for event chain vars."""
     """Base class for event chain vars."""
 
 
 
 
@@ -1592,7 +1593,7 @@ class EventChainVar(FunctionVar, python_types=EventChain):
 # Note: LiteralVar is second in the inheritance list allowing it act like a
 # Note: LiteralVar is second in the inheritance list allowing it act like a
 # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
 # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
 # _cached_var_name property.
 # _cached_var_name property.
-class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
+class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainVar):
     """A literal event chain var."""
     """A literal event chain var."""
 
 
     _var_value: EventChain = dataclasses.field(default=None)  # type: ignore
     _var_value: EventChain = dataclasses.field(default=None)  # type: ignore

+ 2 - 1
reflex/utils/telemetry.py

@@ -51,7 +51,8 @@ def get_python_version() -> str:
     Returns:
     Returns:
         The Python version.
         The Python version.
     """
     """
-    return platform.python_version()
+    # Remove the "+" from the version string in case user is using a pre-release version.
+    return platform.python_version().rstrip("+")
 
 
 
 
 def get_reflex_version() -> str:
 def get_reflex_version() -> str:

+ 18 - 4
reflex/vars/base.py

@@ -361,21 +361,29 @@ class Var(Generic[VAR_TYPE]):
         return False
         return False
 
 
     def __init_subclass__(
     def __init_subclass__(
-        cls, python_types: Tuple[GenericType, ...] | GenericType = types.Unset, **kwargs
+        cls,
+        python_types: Tuple[GenericType, ...] | GenericType = types.Unset(),
+        default_type: GenericType = types.Unset(),
+        **kwargs,
     ):
     ):
         """Initialize the subclass.
         """Initialize the subclass.
 
 
         Args:
         Args:
             python_types: The python types that the var represents.
             python_types: The python types that the var represents.
+            default_type: The default type of the var. Defaults to the first python type.
             **kwargs: Additional keyword arguments.
             **kwargs: Additional keyword arguments.
         """
         """
         super().__init_subclass__(**kwargs)
         super().__init_subclass__(**kwargs)
 
 
-        if python_types is not types.Unset:
+        if python_types or default_type:
             python_types = (
             python_types = (
-                python_types if isinstance(python_types, tuple) else (python_types,)
+                (python_types if isinstance(python_types, tuple) else (python_types,))
+                if python_types
+                else ()
             )
             )
 
 
+            default_type = default_type or (python_types[0] if python_types else Any)
+
             @dataclasses.dataclass(
             @dataclasses.dataclass(
                 eq=False,
                 eq=False,
                 frozen=True,
                 frozen=True,
@@ -388,7 +396,7 @@ class Var(Generic[VAR_TYPE]):
                     default=Var(_js_expr="null", _var_type=None),
                     default=Var(_js_expr="null", _var_type=None),
                 )
                 )
 
 
-                _default_var_type: ClassVar[GenericType] = python_types[0]
+                _default_var_type: ClassVar[GenericType] = default_type
 
 
             ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation'
             ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation'
 
 
@@ -588,6 +596,12 @@ class Var(Generic[VAR_TYPE]):
         output: type[list] | type[tuple] | type[set],
         output: type[list] | type[tuple] | type[set],
     ) -> ArrayVar: ...
     ) -> ArrayVar: ...
 
 
+    @overload
+    def to(
+        self,
+        output: type[dict],
+    ) -> ObjectVar[dict]: ...
+
     @overload
     @overload
     def to(
     def to(
         self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE]
         self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE]

+ 283 - 37
reflex/vars/function.py

@@ -4,32 +4,177 @@ from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
 import sys
 import sys
-from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
+
+from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
 
 
 from reflex.utils import format
 from reflex.utils import format
 from reflex.utils.types import GenericType
 from reflex.utils.types import GenericType
 
 
 from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
 from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
 
 
+P = ParamSpec("P")
+V1 = TypeVar("V1")
+V2 = TypeVar("V2")
+V3 = TypeVar("V3")
+V4 = TypeVar("V4")
+V5 = TypeVar("V5")
+V6 = TypeVar("V6")
+R = TypeVar("R")
+
+
+class ReflexCallable(Protocol[P, R]):
+    """Protocol for a callable."""
+
+    __call__: Callable[P, R]
+
 
 
-class FunctionVar(Var[Callable], python_types=Callable):
+CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
+OTHER_CALLABLE_TYPE = TypeVar(
+    "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
+)
+
+
+class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
     """Base class for immutable function vars."""
     """Base class for immutable function vars."""
 
 
-    def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
-        """Call the function with the given arguments.
+    @overload
+    def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
+        arg1: Union[V1, Var[V1]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+        arg5: Union[V5, Var[V5]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+        arg5: Union[V5, Var[V5]],
+        arg6: Union[V6, Var[V6]],
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(
+        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
+    ) -> FunctionVar[ReflexCallable[P, R]]: ...
+
+    @overload
+    def partial(self, *args: Var | Any) -> FunctionVar: ...
+
+    def partial(self, *args: Var | Any) -> FunctionVar:  # type: ignore
+        """Partially apply the function with the given arguments.
 
 
         Args:
         Args:
-            *args: The arguments to call the function with.
+            *args: The arguments to partially apply the function with.
 
 
         Returns:
         Returns:
-            The function call operation.
+            The partially applied function.
         """
         """
+        if not args:
+            return ArgsFunctionOperation.create((), self)
         return ArgsFunctionOperation.create(
         return ArgsFunctionOperation.create(
             ("...args",),
             ("...args",),
             VarOperationCall.create(self, *args, Var(_js_expr="...args")),
             VarOperationCall.create(self, *args, Var(_js_expr="...args")),
         )
         )
 
 
-    def call(self, *args: Var | Any) -> VarOperationCall:
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
+    ) -> VarOperationCall[[V1], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+    ) -> VarOperationCall[[V1, V2], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+    ) -> VarOperationCall[[V1, V2, V3], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+    ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+        arg5: Union[V5, Var[V5]],
+    ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
+        arg1: Union[V1, Var[V1]],
+        arg2: Union[V2, Var[V2]],
+        arg3: Union[V3, Var[V3]],
+        arg4: Union[V4, Var[V4]],
+        arg5: Union[V5, Var[V5]],
+        arg6: Union[V6, Var[V6]],
+    ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
+
+    @overload
+    def call(
+        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
+    ) -> VarOperationCall[P, R]: ...
+
+    @overload
+    def call(self, *args: Var | Any) -> Var: ...
+
+    def call(self, *args: Var | Any) -> Var:  # type: ignore
         """Call the function with the given arguments.
         """Call the function with the given arguments.
 
 
         Args:
         Args:
@@ -38,19 +183,29 @@ class FunctionVar(Var[Callable], python_types=Callable):
         Returns:
         Returns:
             The function call operation.
             The function call operation.
         """
         """
-        return VarOperationCall.create(self, *args)
+        return VarOperationCall.create(self, *args).guess_type()
+
+    __call__ = call
+
+
+class BuilderFunctionVar(
+    FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
+):
+    """Base class for immutable function vars with the builder pattern."""
+
+    __call__ = FunctionVar.partial
 
 
 
 
-class FunctionStringVar(FunctionVar):
+class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
     """Base class for immutable function vars from a string."""
     """Base class for immutable function vars from a string."""
 
 
     @classmethod
     @classmethod
     def create(
     def create(
         cls,
         cls,
         func: str,
         func: str,
-        _var_type: Type[Callable] = Callable,
+        _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
-    ) -> FunctionStringVar:
+    ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
         """Create a new function var from a string.
         """Create a new function var from a string.
 
 
         Args:
         Args:
@@ -60,7 +215,7 @@ class FunctionStringVar(FunctionVar):
         Returns:
         Returns:
             The function var.
             The function var.
         """
         """
-        return cls(
+        return FunctionStringVar(
             _js_expr=func,
             _js_expr=func,
             _var_type=_var_type,
             _var_type=_var_type,
             _var_data=_var_data,
             _var_data=_var_data,
@@ -72,10 +227,10 @@ class FunctionStringVar(FunctionVar):
     frozen=True,
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
 )
-class VarOperationCall(CachedVarOperation, Var):
+class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
     """Base class for immutable vars that are the result of a function call."""
     """Base class for immutable vars that are the result of a function call."""
 
 
-    _func: Optional[FunctionVar] = dataclasses.field(default=None)
+    _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
     _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
     _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
 
 
     @cached_property_no_lock
     @cached_property_no_lock
@@ -103,7 +258,7 @@ class VarOperationCall(CachedVarOperation, Var):
     @classmethod
     @classmethod
     def create(
     def create(
         cls,
         cls,
-        func: FunctionVar,
+        func: FunctionVar[ReflexCallable[P, R]],
         *args: Var | Any,
         *args: Var | Any,
         _var_type: GenericType = Any,
         _var_type: GenericType = Any,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
@@ -118,9 +273,15 @@ class VarOperationCall(CachedVarOperation, Var):
         Returns:
         Returns:
             The function call var.
             The function call var.
         """
         """
+        function_return_type = (
+            func._var_type.__args__[1]
+            if getattr(func._var_type, "__args__", None)
+            else Any
+        )
+        var_type = _var_type if _var_type is not Any else function_return_type
         return cls(
         return cls(
             _js_expr="",
             _js_expr="",
-            _var_type=_var_type,
+            _var_type=var_type,
             _var_data=_var_data,
             _var_data=_var_data,
             _func=func,
             _func=func,
             _args=args,
             _args=args,
@@ -157,6 +318,33 @@ class FunctionArgs:
     rest: Optional[str] = None
     rest: Optional[str] = None
 
 
 
 
+def format_args_function_operation(
+    args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
+) -> str:
+    """Format an args function operation.
+
+    Args:
+        args: The function arguments.
+        return_expr: The return expression.
+        explicit_return: Whether to use explicit return syntax.
+
+    Returns:
+        The formatted args function operation.
+    """
+    arg_names_str = ", ".join(
+        [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
+    ) + (f", ...{args.rest}" if args.rest else "")
+
+    return_expr_str = str(LiteralVar.create(return_expr))
+
+    # Wrap return expression in curly braces if explicit return syntax is used.
+    return_expr_str_wrapped = (
+        format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
+    )
+
+    return f"(({arg_names_str}) => {return_expr_str_wrapped})"
+
+
 @dataclasses.dataclass(
 @dataclasses.dataclass(
     eq=False,
     eq=False,
     frozen=True,
     frozen=True,
@@ -176,24 +364,10 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         Returns:
         Returns:
             The name of the var.
             The name of the var.
         """
         """
-        arg_names_str = ", ".join(
-            [
-                arg if isinstance(arg, str) else arg.to_javascript()
-                for arg in self._args.args
-            ]
-        ) + (f", ...{self._args.rest}" if self._args.rest else "")
-
-        return_expr_str = str(LiteralVar.create(self._return_expr))
-
-        # Wrap return expression in curly braces if explicit return syntax is used.
-        return_expr_str_wrapped = (
-            format.wrap(return_expr_str, "{", "}")
-            if self._explicit_return
-            else return_expr_str
+        return format_args_function_operation(
+            self._args, self._return_expr, self._explicit_return
         )
         )
 
 
-        return f"(({arg_names_str}) => {return_expr_str_wrapped})"
-
     @classmethod
     @classmethod
     def create(
     def create(
         cls,
         cls,
@@ -203,7 +377,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         explicit_return: bool = False,
         explicit_return: bool = False,
         _var_type: GenericType = Callable,
         _var_type: GenericType = Callable,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
-    ) -> ArgsFunctionOperation:
+    ):
         """Create a new function var.
         """Create a new function var.
 
 
         Args:
         Args:
@@ -226,8 +400,80 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
         )
         )
 
 
 
 
-JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify")
-ARRAY_ISARRAY = FunctionStringVar.create("Array.isArray")
-PROTOTYPE_TO_STRING = FunctionStringVar.create(
-    "((__to_string) => __to_string.toString())"
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
 )
+class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
+    """Base class for immutable function defined via arguments and return expression with the builder pattern."""
+
+    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
+    _return_expr: Union[Var, Any] = dataclasses.field(default=None)
+    _explicit_return: bool = dataclasses.field(default=False)
+
+    @cached_property_no_lock
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        return format_args_function_operation(
+            self._args, self._return_expr, self._explicit_return
+        )
+
+    @classmethod
+    def create(
+        cls,
+        args_names: Sequence[Union[str, DestructuredArg]],
+        return_expr: Var | Any,
+        rest: str | None = None,
+        explicit_return: bool = False,
+        _var_type: GenericType = Callable,
+        _var_data: VarData | None = None,
+    ):
+        """Create a new function var.
+
+        Args:
+            args_names: The names of the arguments.
+            return_expr: The return expression of the function.
+            rest: The name of the rest argument.
+            explicit_return: Whether to use explicit return syntax.
+            _var_data: Additional hooks and imports associated with the Var.
+
+        Returns:
+            The function var.
+        """
+        return cls(
+            _js_expr="",
+            _var_type=_var_type,
+            _var_data=_var_data,
+            _args=FunctionArgs(args=tuple(args_names), rest=rest),
+            _return_expr=return_expr,
+            _explicit_return=explicit_return,
+        )
+
+
+if python_version := sys.version_info[:2] >= (3, 10):
+    JSON_STRINGIFY = FunctionStringVar.create(
+        "JSON.stringify", _var_type=ReflexCallable[[Any], str]
+    )
+    ARRAY_ISARRAY = FunctionStringVar.create(
+        "Array.isArray", _var_type=ReflexCallable[[Any], bool]
+    )
+    PROTOTYPE_TO_STRING = FunctionStringVar.create(
+        "((__to_string) => __to_string.toString())",
+        _var_type=ReflexCallable[[Any], str],
+    )
+else:
+    JSON_STRINGIFY = FunctionStringVar.create(
+        "JSON.stringify", _var_type=ReflexCallable[Any, str]
+    )
+    ARRAY_ISARRAY = FunctionStringVar.create(
+        "Array.isArray", _var_type=ReflexCallable[Any, bool]
+    )
+    PROTOTYPE_TO_STRING = FunctionStringVar.create(
+        "((__to_string) => __to_string.toString())",
+        _var_type=ReflexCallable[Any, str],
+    )

+ 1 - 1
tests/units/test_var.py

@@ -928,7 +928,7 @@ def test_function_var():
         == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
         == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
     )
     )
 
 
-    increment_func = addition_func(1)
+    increment_func = addition_func.partial(1)
     assert (
     assert (
         str(increment_func.call(2))
         str(increment_func.call(2))
         == "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"
         == "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"