浏览代码

[REF-3227] implement more literal vars (#3687)

* implement more literal vars

* fix super issue

* pyright has a bug i think

* oh we changed that

* fix docs

* literalize vars recursively

* do what masen told me :D

* use dynamic keys

* forgot .create

* adjust _var_value

* dang it darglint

* add test for serializing literal vars into js exprs

* fix silly mistake

* add  handling for var and none

* use create safe

* is none bruh

* implement function vars and do various modification

* fix None issue

* clear a lot of creates that did nothing

* add tests to function vars

* added simple fix smh

* use fconcat to make an even more complicated test
Khaleel Al-Adhami 10 月之前
父节点
当前提交
ea016314b0
共有 3 个文件被更改,包括 594 次插入32 次删除
  1. 6 0
      reflex/experimental/vars/__init__.py
  2. 532 31
      reflex/experimental/vars/base.py
  3. 56 1
      tests/test_var.py

+ 6 - 0
reflex/experimental/vars/__init__.py

@@ -3,10 +3,16 @@
 from .base import ArrayVar as ArrayVar
 from .base import BooleanVar as BooleanVar
 from .base import ConcatVarOperation as ConcatVarOperation
+from .base import FunctionStringVar as FunctionStringVar
 from .base import FunctionVar as FunctionVar
 from .base import ImmutableVar as ImmutableVar
+from .base import LiteralArrayVar as LiteralArrayVar
+from .base import LiteralBooleanVar as LiteralBooleanVar
+from .base import LiteralNumberVar as LiteralNumberVar
+from .base import LiteralObjectVar as LiteralObjectVar
 from .base import LiteralStringVar as LiteralStringVar
 from .base import LiteralVar as LiteralVar
 from .base import NumberVar as NumberVar
 from .base import ObjectVar as ObjectVar
 from .base import StringVar as StringVar
+from .base import VarOperationCall as VarOperationCall

+ 532 - 31
reflex/experimental/vars/base.py

@@ -7,9 +7,10 @@ import json
 import re
 import sys
 from functools import cached_property
-from typing import Any, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 from reflex import constants
+from reflex.base import Base
 from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
 from reflex.utils import serializers, types
 from reflex.utils.exceptions import VarTypeError
@@ -95,6 +96,11 @@ class ImmutableVar(Var):
         return hash((self._var_name, self._var_type, self._var_data))
 
     def _get_all_var_data(self) -> ImmutableVarData | None:
+        """Get all VarData associated with the Var.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
         return self._var_data
 
     def _replace(self, merge_var_data=None, **kwargs: Any):
@@ -275,10 +281,250 @@ class ArrayVar(ImmutableVar):
 class FunctionVar(ImmutableVar):
     """Base class for immutable function vars."""
 
+    def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
+        """Call the function with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            The function call operation.
+        """
+        return ArgsFunctionOperation(
+            ("...args",),
+            VarOperationCall(self, *args, ImmutableVar.create_safe("...args")),
+        )
+
+    def call(self, *args: Var | Any) -> VarOperationCall:
+        """Call the function with the given arguments.
+
+        Args:
+            *args: The arguments to call the function with.
+
+        Returns:
+            The function call operation.
+        """
+        return VarOperationCall(self, *args)
+
+
+class FunctionStringVar(FunctionVar):
+    """Base class for immutable function vars from a string."""
+
+    def __init__(self, func: str, _var_data: VarData | None = None) -> None:
+        """Initialize the function var.
+
+        Args:
+            func: The function to call.
+            _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(FunctionVar, self).__init__(
+            _var_name=func,
+            _var_type=Callable,
+            _var_data=ImmutableVarData.merge(_var_data),
+        )
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class VarOperationCall(ImmutableVar):
+    """Base class for immutable vars that are the result of a function call."""
+
+    _func: Optional[FunctionVar] = dataclasses.field(default=None)
+    _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
+
+    def __init__(
+        self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None
+    ):
+        """Initialize the function call var.
+
+        Args:
+            func: The function to call.
+            *args: The arguments to call the function with.
+            _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(VarOperationCall, self).__init__(
+            _var_name="",
+            _var_type=Callable,
+            _var_data=ImmutableVarData.merge(_var_data),
+        )
+        object.__setattr__(self, "_func", func)
+        object.__setattr__(self, "_args", args)
+        object.__delattr__(self, "_var_name")
+
+    def __getattr__(self, name):
+        """Get an attribute of the var.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The attribute of the var.
+        """
+        if name == "_var_name":
+            return self._cached_var_name
+        return super(type(self), self).__getattr__(name)
+
+    @cached_property
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
+
+    @cached_property
+    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
+        """Get all VarData associated with the Var.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
+        return ImmutableVarData.merge(
+            self._func._get_all_var_data() if self._func is not None else None,
+            *[var._get_all_var_data() for var in self._args],
+            self._var_data,
+        )
+
+    def _get_all_var_data(self) -> ImmutableVarData | None:
+        """Wrapper method for cached property.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
+        return self._cached_get_all_var_data
+
+    def __post_init__(self):
+        """Post-initialize the var."""
+        pass
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class ArgsFunctionOperation(FunctionVar):
+    """Base class for immutable function defined via arguments and return expression."""
+
+    _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
+    _return_expr: Union[Var, Any] = dataclasses.field(default=None)
+
+    def __init__(
+        self,
+        args_names: Tuple[str, ...],
+        return_expr: Var | Any,
+        _var_data: VarData | None = None,
+    ) -> None:
+        """Initialize the function with arguments var.
+
+        Args:
+            args_names: The names of the arguments.
+            return_expr: The return expression of the function.
+            _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(ArgsFunctionOperation, self).__init__(
+            _var_name=f"",
+            _var_type=Callable,
+            _var_data=ImmutableVarData.merge(_var_data),
+        )
+        object.__setattr__(self, "_args_names", args_names)
+        object.__setattr__(self, "_return_expr", return_expr)
+        object.__delattr__(self, "_var_name")
+
+    def __getattr__(self, name):
+        """Get an attribute of the var.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The attribute of the var.
+        """
+        if name == "_var_name":
+            return self._cached_var_name
+        return super(type(self), self).__getattr__(name)
+
+    @cached_property
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
+
+    @cached_property
+    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
+        """Get all VarData associated with the Var.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
+        return ImmutableVarData.merge(
+            self._return_expr._get_all_var_data(),
+            self._var_data,
+        )
+
+    def _get_all_var_data(self) -> ImmutableVarData | None:
+        """Wrapper method for cached property.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
+        return self._cached_get_all_var_data
+
+    def __post_init__(self):
+        """Post-initialize the var."""
+
 
 class LiteralVar(ImmutableVar):
     """Base class for immutable literal vars."""
 
+    @classmethod
+    def create(
+        cls,
+        value: Any,
+        _var_data: VarData | None = None,
+    ) -> Var:
+        """Create a var from a value.
+
+        Args:
+            value: The value to create the var from.
+            _var_data: Additional hooks and imports associated with the Var.
+
+        Returns:
+            The var.
+
+        Raises:
+            TypeError: If the value is not a supported type for LiteralVar.
+        """
+        if isinstance(value, Var):
+            if _var_data is None:
+                return value
+            return value._replace(merge_var_data=_var_data)
+
+        if value is None:
+            return ImmutableVar.create_safe("null", _var_data=_var_data)
+
+        if isinstance(value, Base):
+            return LiteralObjectVar(
+                value.dict(), _var_type=type(value), _var_data=_var_data
+            )
+
+        if isinstance(value, str):
+            return LiteralStringVar.create(value, _var_data=_var_data)
+
+        constructor = type_mapping.get(type(value))
+
+        if constructor is None:
+            raise TypeError(f"Unsupported type {type(value)} for LiteralVar.")
+
+        return constructor(value, _var_data=_var_data)
+
     def __post_init__(self):
         """Post-initialize the var."""
 
@@ -298,7 +544,25 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
 class LiteralStringVar(LiteralVar):
     """Base class for immutable literal string vars."""
 
-    _var_value: Optional[str] = dataclasses.field(default=None)
+    _var_value: str = dataclasses.field(default="")
+
+    def __init__(
+        self,
+        _var_value: str,
+        _var_data: VarData | None = None,
+    ):
+        """Initialize the string var.
+
+        Args:
+            _var_value: The value of the var.
+            _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(LiteralStringVar, self).__init__(
+            _var_name=f'"{_var_value}"',
+            _var_type=str,
+            _var_data=ImmutableVarData.merge(_var_data),
+        )
+        object.__setattr__(self, "_var_value", _var_value)
 
     @classmethod
     def create(
@@ -316,7 +580,7 @@ class LiteralStringVar(LiteralVar):
             The var.
         """
         if REFLEX_VAR_OPENING_TAG in value:
-            strings_and_vals: list[Var] = []
+            strings_and_vals: list[Var | str] = []
             offset = 0
 
             # Initialize some methods for reading json.
@@ -334,7 +598,7 @@ class LiteralStringVar(LiteralVar):
             while m := _decode_var_pattern.search(value):
                 start, end = m.span()
                 if start > 0:
-                    strings_and_vals.append(LiteralStringVar.create(value[:start]))
+                    strings_and_vals.append(value[:start])
 
                 serialized_data = m.group(1)
 
@@ -364,17 +628,13 @@ class LiteralStringVar(LiteralVar):
                 offset += end - start
 
             if value:
-                strings_and_vals.append(LiteralStringVar.create(value))
+                strings_and_vals.append(value)
 
-            return ConcatVarOperation.create(
-                tuple(strings_and_vals), _var_data=_var_data
-            )
+            return ConcatVarOperation(*strings_and_vals, _var_data=_var_data)
 
-        return cls(
-            _var_value=value,
-            _var_name=f'"{value}"',
-            _var_type=str,
-            _var_data=ImmutableVarData.merge(_var_data),
+        return LiteralStringVar(
+            value,
+            _var_data=_var_data,
         )
 
 
@@ -386,20 +646,33 @@ class LiteralStringVar(LiteralVar):
 class ConcatVarOperation(StringVar):
     """Representing a concatenation of literal string vars."""
 
-    _var_value: tuple[Var, ...] = dataclasses.field(default_factory=tuple)
+    _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple)
 
-    def __init__(self, _var_value: tuple[Var, ...], _var_data: VarData | None = None):
+    def __init__(self, *value: Var | str, _var_data: VarData | None = None):
         """Initialize the operation of concatenating literal string vars.
 
         Args:
-            _var_value: The list of vars to concatenate.
+            value: The values to concatenate.
             _var_data: Additional hooks and imports associated with the Var.
         """
         super(ConcatVarOperation, self).__init__(
             _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str
         )
-        object.__setattr__(self, "_var_value", _var_value)
-        object.__setattr__(self, "_var_name", self._cached_var_name)
+        object.__setattr__(self, "_var_value", value)
+        object.__delattr__(self, "_var_name")
+
+    def __getattr__(self, name):
+        """Get an attribute of the var.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The attribute of the var.
+        """
+        if name == "_var_name":
+            return self._cached_var_name
+        return super(type(self), self).__getattr__(name)
 
     @cached_property
     def _cached_var_name(self) -> str:
@@ -408,7 +681,16 @@ class ConcatVarOperation(StringVar):
         Returns:
             The name of the var.
         """
-        return "+".join([str(element) for element in self._var_value])
+        return (
+            "("
+            + "+".join(
+                [
+                    str(element) if isinstance(element, Var) else f'"{element}"'
+                    for element in self._var_value
+                ]
+            )
+            + ")"
+        )
 
     @cached_property
     def _cached_get_all_var_data(self) -> ImmutableVarData | None:
@@ -418,7 +700,12 @@ class ConcatVarOperation(StringVar):
             The VarData of the components and all of its children.
         """
         return ImmutableVarData.merge(
-            *[var._get_all_var_data() for var in self._var_value], self._var_data
+            *[
+                var._get_all_var_data()
+                for var in self._var_value
+                if isinstance(var, Var)
+            ],
+            self._var_data,
         )
 
     def _get_all_var_data(self) -> ImmutableVarData | None:
@@ -433,22 +720,236 @@ class ConcatVarOperation(StringVar):
         """Post-initialize the var."""
         pass
 
-    @classmethod
-    def create(
-        cls,
-        value: tuple[Var, ...],
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class LiteralBooleanVar(LiteralVar):
+    """Base class for immutable literal boolean vars."""
+
+    _var_value: bool = dataclasses.field(default=False)
+
+    def __init__(
+        self,
+        _var_value: bool,
         _var_data: VarData | None = None,
-    ) -> ConcatVarOperation:
-        """Create a var from a tuple of values.
+    ):
+        """Initialize the boolean var.
 
         Args:
-            value: The value to create the var from.
+            _var_value: The value of the var.
             _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(LiteralBooleanVar, self).__init__(
+            _var_name="true" if _var_value else "false",
+            _var_type=bool,
+            _var_data=ImmutableVarData.merge(_var_data),
+        )
+        object.__setattr__(self, "_var_value", _var_value)
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class LiteralNumberVar(LiteralVar):
+    """Base class for immutable literal number vars."""
+
+    _var_value: float | int = dataclasses.field(default=0)
+
+    def __init__(
+        self,
+        _var_value: float | int,
+        _var_data: VarData | None = None,
+    ):
+        """Initialize the number var.
+
+        Args:
+            _var_value: The value of the var.
+            _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(LiteralNumberVar, self).__init__(
+            _var_name=str(_var_value),
+            _var_type=type(_var_value),
+            _var_data=ImmutableVarData.merge(_var_data),
+        )
+        object.__setattr__(self, "_var_value", _var_value)
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class LiteralObjectVar(LiteralVar):
+    """Base class for immutable literal object vars."""
+
+    _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
+        default_factory=dict
+    )
+
+    def __init__(
+        self,
+        _var_value: dict[Var | Any, Var | Any],
+        _var_type: Type = dict,
+        _var_data: VarData | None = None,
+    ):
+        """Initialize the object var.
+
+        Args:
+            _var_value: The value of the var.
+            _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(LiteralObjectVar, self).__init__(
+            _var_name="",
+            _var_type=_var_type,
+            _var_data=ImmutableVarData.merge(_var_data),
+        )
+        object.__setattr__(
+            self,
+            "_var_value",
+            _var_value,
+        )
+        object.__delattr__(self, "_var_name")
+
+    def __getattr__(self, name):
+        """Get an attribute of the var.
+
+        Args:
+            name: The name of the attribute.
 
         Returns:
-            The var.
+            The attribute of the var.
         """
-        return ConcatVarOperation(
-            _var_value=value,
-            _var_data=_var_data,
+        if name == "_var_name":
+            return self._cached_var_name
+        return super(type(self), self).__getattr__(name)
+
+    @cached_property
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        return (
+            "{ "
+            + ", ".join(
+                [
+                    f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}"
+                    for key, value in self._var_value.items()
+                ]
+            )
+            + " }"
+        )
+
+    @cached_property
+    def _get_all_var_data(self) -> ImmutableVarData | None:
+        """Get all VarData associated with the Var.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
+        return ImmutableVarData.merge(
+            *[
+                value._get_all_var_data()
+                for key, value in self._var_value
+                if isinstance(value, Var)
+            ],
+            *[
+                key._get_all_var_data()
+                for key, value in self._var_value
+                if isinstance(key, Var)
+            ],
+            self._var_data,
+        )
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class LiteralArrayVar(LiteralVar):
+    """Base class for immutable literal array vars."""
+
+    _var_value: Union[
+        List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...]
+    ] = dataclasses.field(default_factory=list)
+
+    def __init__(
+        self,
+        _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any],
+        _var_data: VarData | None = None,
+    ):
+        """Initialize the array var.
+
+        Args:
+            _var_value: The value of the var.
+            _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(LiteralArrayVar, self).__init__(
+            _var_name="",
+            _var_data=ImmutableVarData.merge(_var_data),
+            _var_type=list,
         )
+        object.__setattr__(self, "_var_value", _var_value)
+        object.__delattr__(self, "_var_name")
+
+    def __getattr__(self, name):
+        """Get an attribute of the var.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The attribute of the var.
+        """
+        if name == "_var_name":
+            return self._cached_var_name
+        return super(type(self), self).__getattr__(name)
+
+    @cached_property
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        return (
+            "["
+            + ", ".join(
+                [str(LiteralVar.create(element)) for element in self._var_value]
+            )
+            + "]"
+        )
+
+    @cached_property
+    def _get_all_var_data(self) -> ImmutableVarData | None:
+        """Get all VarData associated with the Var.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
+        return ImmutableVarData.merge(
+            *[
+                var._get_all_var_data()
+                for var in self._var_value
+                if isinstance(var, Var)
+            ],
+            self._var_data,
+        )
+
+
+type_mapping = {
+    int: LiteralNumberVar,
+    float: LiteralNumberVar,
+    bool: LiteralBooleanVar,
+    dict: LiteralObjectVar,
+    list: LiteralArrayVar,
+    tuple: LiteralArrayVar,
+    set: LiteralArrayVar,
+}

+ 56 - 1
tests/test_var.py

@@ -8,9 +8,12 @@ from pandas import DataFrame
 from reflex.base import Base
 from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
 from reflex.experimental.vars.base import (
+    ArgsFunctionOperation,
     ConcatVarOperation,
+    FunctionStringVar,
     ImmutableVar,
     LiteralStringVar,
+    LiteralVar,
 )
 from reflex.state import BaseState
 from reflex.utils.imports import ImportVar
@@ -858,6 +861,58 @@ def test_state_with_initial_computed_var(
         assert runtime_dict[var_name] == expected_runtime
 
 
+def test_literal_var():
+    complicated_var = LiteralVar.create(
+        [
+            {"a": 1, "b": 2, "c": {"d": 3, "e": 4}},
+            [1, 2, 3, 4],
+            9,
+            "string",
+            True,
+            False,
+            None,
+            set([1, 2, 3]),
+        ]
+    )
+    assert (
+        str(complicated_var)
+        == '[{ ["a"] : 1, ["b"] : 2, ["c"] : { ["d"] : 3, ["e"] : 4 } }, [1, 2, 3, 4], 9, "string", true, false, null, [1, 2, 3]]'
+    )
+
+
+def test_function_var():
+    addition_func = FunctionStringVar("((a, b) => a + b)")
+    assert str(addition_func.call(1, 2)) == "(((a, b) => a + b)(1, 2))"
+
+    manual_addition_func = ArgsFunctionOperation(
+        ("a", "b"),
+        {
+            "args": [ImmutableVar.create_safe("a"), ImmutableVar.create_safe("b")],
+            "result": ImmutableVar.create_safe("a + b"),
+        },
+    )
+    assert (
+        str(manual_addition_func.call(1, 2))
+        == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
+    )
+
+    increment_func = addition_func(1)
+    assert (
+        str(increment_func.call(2))
+        == "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))"
+    )
+
+    create_hello_statement = ArgsFunctionOperation(
+        ("name",), f"Hello, {ImmutableVar.create_safe('name')}!"
+    )
+    first_name = LiteralStringVar("Steven")
+    last_name = LiteralStringVar("Universe")
+    assert (
+        str(create_hello_statement.call(f"{first_name} {last_name}"))
+        == '(((name) => (("Hello, "+name+"!")))(("Steven"+" "+"Universe")))'
+    )
+
+
 def test_retrival():
     var_without_data = ImmutableVar.create("test")
     assert var_without_data is not None
@@ -931,7 +986,7 @@ def test_fstring_concat():
         ),
     )
 
-    assert str(string_concat) == '"foo"+imagination+"bar"+consequences+"baz"'
+    assert str(string_concat) == '("foo"+imagination+"bar"+consequences+"baz")'
     assert isinstance(string_concat, ConcatVarOperation)
     assert string_concat._get_all_var_data() == ImmutableVarData(
         state="fear",