Bläddra i källkod

[REF-3228] implement LiteralStringVar and format/retrieval mechanism (#3669)

* implement LiteralStringVar and format/retrieval mechanism

* use create safe

* add cached properties to ConcatVarOperation

* fix caches

* also include self

* fix inconsistencies in typings

* use default factory not default

* add missing docstring

* experiment with immutable var data

* solve pydantic issues

* add sorted function

* missing docs

* forgot ellipses

* give up on frozen

* dang it darglint

* fix string serialization bugs and remove unused code

* add returns statement

* whitespace moment

* add simple test for string concat

* export ConcatVarOperation
Khaleel Al-Adhami 10 månader sedan
förälder
incheckning
458cbfac59
6 ändrade filer med 563 tillägg och 29 borttagningar
  1. 3 0
      reflex/experimental/vars/__init__.py
  2. 226 8
      reflex/experimental/vars/base.py
  3. 71 8
      reflex/utils/imports.py
  4. 185 6
      reflex/vars.py
  5. 20 2
      reflex/vars.pyi
  6. 58 5
      tests/test_var.py

+ 3 - 0
reflex/experimental/vars/__init__.py

@@ -2,8 +2,11 @@
 
 
 from .base import ArrayVar as ArrayVar
 from .base import ArrayVar as ArrayVar
 from .base import BooleanVar as BooleanVar
 from .base import BooleanVar as BooleanVar
+from .base import ConcatVarOperation as ConcatVarOperation
 from .base import FunctionVar as FunctionVar
 from .base import FunctionVar as FunctionVar
 from .base import ImmutableVar as ImmutableVar
 from .base import ImmutableVar as ImmutableVar
+from .base import LiteralStringVar as LiteralStringVar
+from .base import LiteralVar as LiteralVar
 from .base import NumberVar as NumberVar
 from .base import NumberVar as NumberVar
 from .base import ObjectVar as ObjectVar
 from .base import ObjectVar as ObjectVar
 from .base import StringVar as StringVar
 from .base import StringVar as StringVar

+ 226 - 8
reflex/experimental/vars/base.py

@@ -3,13 +3,24 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
+import json
+import re
 import sys
 import sys
+from functools import cached_property
 from typing import Any, Optional, Type
 from typing import Any, Optional, Type
 
 
+from reflex import constants
 from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
 from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
 from reflex.utils import serializers, types
 from reflex.utils import serializers, types
 from reflex.utils.exceptions import VarTypeError
 from reflex.utils.exceptions import VarTypeError
-from reflex.vars import Var, VarData, _decode_var, _extract_var_data, _global_vars
+from reflex.vars import (
+    ImmutableVarData,
+    Var,
+    VarData,
+    _decode_var_immutable,
+    _extract_var_data,
+    _global_vars,
+)
 
 
 
 
 @dataclasses.dataclass(
 @dataclasses.dataclass(
@@ -27,7 +38,15 @@ class ImmutableVar(Var):
     _var_type: Type = dataclasses.field(default=Any)
     _var_type: Type = dataclasses.field(default=Any)
 
 
     # Extra metadata associated with the Var
     # Extra metadata associated with the Var
-    _var_data: Optional[VarData] = dataclasses.field(default=None)
+    _var_data: Optional[ImmutableVarData] = dataclasses.field(default=None)
+
+    def __str__(self) -> str:
+        """String representation of the var. Guaranteed to be a valid Javascript expression.
+
+        Returns:
+            The name of the var.
+        """
+        return self._var_name
 
 
     @property
     @property
     def _var_is_local(self) -> bool:
     def _var_is_local(self) -> bool:
@@ -59,12 +78,25 @@ class ImmutableVar(Var):
     def __post_init__(self):
     def __post_init__(self):
         """Post-initialize the var."""
         """Post-initialize the var."""
         # Decode any inline Var markup and apply it to the instance
         # Decode any inline Var markup and apply it to the instance
-        _var_data, _var_name = _decode_var(self._var_name)
+        _var_data, _var_name = _decode_var_immutable(self._var_name)
         if _var_data:
         if _var_data:
             self.__init__(
             self.__init__(
-                _var_name, self._var_type, VarData.merge(self._var_data, _var_data)
+                _var_name,
+                self._var_type,
+                ImmutableVarData.merge(self._var_data, _var_data),
             )
             )
 
 
+    def __hash__(self) -> int:
+        """Define a hash function for the var.
+
+        Returns:
+            The hash of the var.
+        """
+        return hash((self._var_name, self._var_type, self._var_data))
+
+    def _get_all_var_data(self) -> ImmutableVarData | None:
+        return self._var_data
+
     def _replace(self, merge_var_data=None, **kwargs: Any):
     def _replace(self, merge_var_data=None, **kwargs: Any):
         """Make a copy of this Var with updated fields.
         """Make a copy of this Var with updated fields.
 
 
@@ -96,7 +128,7 @@ class ImmutableVar(Var):
         field_values = dict(
         field_values = dict(
             _var_name=kwargs.pop("_var_name", self._var_name),
             _var_name=kwargs.pop("_var_name", self._var_name),
             _var_type=kwargs.pop("_var_type", self._var_type),
             _var_type=kwargs.pop("_var_type", self._var_type),
-            _var_data=VarData.merge(
+            _var_data=ImmutableVarData.merge(
                 kwargs.get("_var_data", self._var_data), merge_var_data
                 kwargs.get("_var_data", self._var_data), merge_var_data
             ),
             ),
         )
         )
@@ -109,7 +141,7 @@ class ImmutableVar(Var):
         _var_is_local: bool | None = None,
         _var_is_local: bool | None = None,
         _var_is_string: bool | None = None,
         _var_is_string: bool | None = None,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
-    ) -> Var | None:
+    ) -> ImmutableVar | Var | None:
         """Create a var from a value.
         """Create a var from a value.
 
 
         Args:
         Args:
@@ -164,7 +196,15 @@ class ImmutableVar(Var):
         return cls(
         return cls(
             _var_name=name,
             _var_name=name,
             _var_type=type_,
             _var_type=type_,
-            _var_data=_var_data,
+            _var_data=(
+                ImmutableVarData(
+                    state=_var_data.state,
+                    imports=_var_data.imports,
+                    hooks=_var_data.hooks,
+                )
+                if _var_data
+                else None
+            ),
         )
         )
 
 
     @classmethod
     @classmethod
@@ -174,7 +214,7 @@ class ImmutableVar(Var):
         _var_is_local: bool | None = None,
         _var_is_local: bool | None = None,
         _var_is_string: bool | None = None,
         _var_is_string: bool | None = None,
         _var_data: VarData | None = None,
         _var_data: VarData | None = None,
-    ) -> Var:
+    ) -> Var | ImmutableVar:
         """Create a var from a value, asserting that it is not None.
         """Create a var from a value, asserting that it is not None.
 
 
         Args:
         Args:
@@ -234,3 +274,181 @@ class ArrayVar(ImmutableVar):
 
 
 class FunctionVar(ImmutableVar):
 class FunctionVar(ImmutableVar):
     """Base class for immutable function vars."""
     """Base class for immutable function vars."""
+
+
+class LiteralVar(ImmutableVar):
+    """Base class for immutable literal vars."""
+
+    def __post_init__(self):
+        """Post-initialize the var."""
+
+
+# Compile regex for finding reflex var tags.
+_decode_var_pattern_re = (
+    rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
+)
+_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class LiteralStringVar(LiteralVar):
+    """Base class for immutable literal string vars."""
+
+    _var_value: Optional[str] = dataclasses.field(default=None)
+
+    @classmethod
+    def create(
+        cls,
+        value: str,
+        _var_data: VarData | None = None,
+    ) -> LiteralStringVar | ConcatVarOperation:
+        """Create a var from a string value.
+
+        Args:
+            value: The value to create the var from.
+            _var_data: Additional hooks and imports associated with the Var.
+
+        Returns:
+            The var.
+        """
+        if REFLEX_VAR_OPENING_TAG in value:
+            strings_and_vals: list[Var] = []
+            offset = 0
+
+            # Initialize some methods for reading json.
+            var_data_config = VarData().__config__
+
+            def json_loads(s):
+                try:
+                    return var_data_config.json_loads(s)
+                except json.decoder.JSONDecodeError:
+                    return var_data_config.json_loads(
+                        var_data_config.json_loads(f'"{s}"')
+                    )
+
+            # Find all tags.
+            while m := _decode_var_pattern.search(value):
+                start, end = m.span()
+                if start > 0:
+                    strings_and_vals.append(LiteralStringVar.create(value[:start]))
+
+                serialized_data = m.group(1)
+
+                if serialized_data[1:].isnumeric():
+                    # This is a global immutable var.
+                    var = _global_vars[int(serialized_data)]
+                    strings_and_vals.append(var)
+                    value = value[(end + len(var._var_name)) :]
+                else:
+                    data = json_loads(serialized_data)
+                    string_length = data.pop("string_length", None)
+                    var_data = VarData.parse_obj(data)
+
+                    # Use string length to compute positions of interpolations.
+                    if string_length is not None:
+                        realstart = start + offset
+                        var_data.interpolations = [
+                            (realstart, realstart + string_length)
+                        ]
+                        strings_and_vals.append(
+                            ImmutableVar.create_safe(
+                                value[end : (end + string_length)], _var_data=var_data
+                            )
+                        )
+                        value = value[(end + string_length) :]
+
+                offset += end - start
+
+            if value:
+                strings_and_vals.append(LiteralStringVar.create(value))
+
+            return ConcatVarOperation.create(
+                tuple(strings_and_vals), _var_data=_var_data
+            )
+
+        return cls(
+            _var_value=value,
+            _var_name=f'"{value}"',
+            _var_type=str,
+            _var_data=ImmutableVarData.merge(_var_data),
+        )
+
+
+@dataclasses.dataclass(
+    eq=False,
+    frozen=True,
+    **{"slots": True} if sys.version_info >= (3, 10) else {},
+)
+class ConcatVarOperation(StringVar):
+    """Representing a concatenation of literal string vars."""
+
+    _var_value: tuple[Var, ...] = dataclasses.field(default_factory=tuple)
+
+    def __init__(self, _var_value: tuple[Var, ...], _var_data: VarData | None = None):
+        """Initialize the operation of concatenating literal string vars.
+
+        Args:
+            _var_value: The list of vars to concatenate.
+            _var_data: Additional hooks and imports associated with the Var.
+        """
+        super(ConcatVarOperation, self).__init__(
+            _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str
+        )
+        object.__setattr__(self, "_var_value", _var_value)
+        object.__setattr__(self, "_var_name", self._cached_var_name)
+
+    @cached_property
+    def _cached_var_name(self) -> str:
+        """The name of the var.
+
+        Returns:
+            The name of the var.
+        """
+        return "+".join([str(element) for element in self._var_value])
+
+    @cached_property
+    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
+        """Get all VarData associated with the Var.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
+        return ImmutableVarData.merge(
+            *[var._get_all_var_data() for var in self._var_value], self._var_data
+        )
+
+    def _get_all_var_data(self) -> ImmutableVarData | None:
+        """Wrapper method for cached property.
+
+        Returns:
+            The VarData of the components and all of its children.
+        """
+        return self._cached_get_all_var_data
+
+    def __post_init__(self):
+        """Post-initialize the var."""
+        pass
+
+    @classmethod
+    def create(
+        cls,
+        value: tuple[Var, ...],
+        _var_data: VarData | None = None,
+    ) -> ConcatVarOperation:
+        """Create a var from a tuple of values.
+
+        Args:
+            value: The value to create the var from.
+            _var_data: Additional hooks and imports associated with the Var.
+
+        Returns:
+            The var.
+        """
+        return ConcatVarOperation(
+            _var_value=value,
+            _var_data=_var_data,
+        )

+ 71 - 8
reflex/utils/imports.py

@@ -3,12 +3,14 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 from collections import defaultdict
 from collections import defaultdict
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 
 from reflex.base import Base
 from reflex.base import Base
 
 
 
 
-def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
+def merge_imports(
+    *imports: ImportDict | ParsedImportDict | ImmutableParsedImportDict,
+) -> ParsedImportDict:
     """Merge multiple import dicts together.
     """Merge multiple import dicts together.
 
 
     Args:
     Args:
@@ -19,7 +21,9 @@ def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
     """
     """
     all_imports = defaultdict(list)
     all_imports = defaultdict(list)
     for import_dict in imports:
     for import_dict in imports:
-        for lib, fields in import_dict.items():
+        for lib, fields in (
+            import_dict if isinstance(import_dict, tuple) else import_dict.items()
+        ):
             all_imports[lib].extend(fields)
             all_imports[lib].extend(fields)
     return all_imports
     return all_imports
 
 
@@ -48,7 +52,9 @@ def parse_imports(imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
     }
     }
 
 
 
 
-def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict:
+def collapse_imports(
+    imports: ParsedImportDict | ImmutableParsedImportDict,
+) -> ParsedImportDict:
     """Remove all duplicate ImportVar within an ImportDict.
     """Remove all duplicate ImportVar within an ImportDict.
 
 
     Args:
     Args:
@@ -58,8 +64,14 @@ def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict:
         The collapsed import dict.
         The collapsed import dict.
     """
     """
     return {
     return {
-        lib: list(set(import_vars)) if isinstance(import_vars, list) else import_vars
-        for lib, import_vars in imports.items()
+        lib: (
+            list(set(import_vars))
+            if isinstance(import_vars, list)
+            else list(import_vars)
+        )
+        for lib, import_vars in (
+            imports if isinstance(imports, tuple) else imports.items()
+        )
     }
     }
 
 
 
 
@@ -99,11 +111,61 @@ class ImportVar(Base):
         else:
         else:
             return self.tag or ""
             return self.tag or ""
 
 
+    def __lt__(self, other: ImportVar) -> bool:
+        """Compare two ImportVar objects.
+
+        Args:
+            other: The other ImportVar object to compare.
+
+        Returns:
+            Whether this ImportVar object is less than the other.
+        """
+        return (
+            self.tag,
+            self.is_default,
+            self.alias,
+            self.install,
+            self.render,
+            self.transpile,
+        ) < (
+            other.tag,
+            other.is_default,
+            other.alias,
+            other.install,
+            other.render,
+            other.transpile,
+        )
+
+    def __eq__(self, other: ImportVar) -> bool:
+        """Check if two ImportVar objects are equal.
+
+        Args:
+            other: The other ImportVar object to compare.
+
+        Returns:
+            Whether the two ImportVar objects are equal.
+        """
+        return (
+            self.tag,
+            self.is_default,
+            self.alias,
+            self.install,
+            self.render,
+            self.transpile,
+        ) == (
+            other.tag,
+            other.is_default,
+            other.alias,
+            other.install,
+            other.render,
+            other.transpile,
+        )
+
     def __hash__(self) -> int:
     def __hash__(self) -> int:
-        """Define a hash function for the import var.
+        """Hash the ImportVar object.
 
 
         Returns:
         Returns:
-            The hash of the var.
+            The hash of the ImportVar object.
         """
         """
         return hash(
         return hash(
             (
             (
@@ -120,3 +182,4 @@ class ImportVar(Base):
 ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
 ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
 ImportDict = Dict[str, ImportTypes]
 ImportDict = Dict[str, ImportTypes]
 ParsedImportDict = Dict[str, List[ImportVar]]
 ParsedImportDict = Dict[str, List[ImportVar]]
+ImmutableParsedImportDict = Tuple[Tuple[str, Tuple[ImportVar, ...]], ...]

+ 185 - 6
reflex/vars.py

@@ -45,6 +45,7 @@ from reflex.utils.exceptions import (
 
 
 # This module used to export ImportVar itself, so we still import it for export here
 # This module used to export ImportVar itself, so we still import it for export here
 from reflex.utils.imports import (
 from reflex.utils.imports import (
+    ImmutableParsedImportDict,
     ImportDict,
     ImportDict,
     ImportVar,
     ImportVar,
     ParsedImportDict,
     ParsedImportDict,
@@ -154,7 +155,7 @@ class VarData(Base):
         super().__init__(**kwargs)
         super().__init__(**kwargs)
 
 
     @classmethod
     @classmethod
-    def merge(cls, *others: VarData | None) -> VarData | None:
+    def merge(cls, *others: ImmutableVarData | VarData | None) -> VarData | None:
         """Merge multiple var data objects.
         """Merge multiple var data objects.
 
 
         Args:
         Args:
@@ -172,8 +173,14 @@ class VarData(Base):
                 continue
                 continue
             state = state or var_data.state
             state = state or var_data.state
             _imports = imports.merge_imports(_imports, var_data.imports)
             _imports = imports.merge_imports(_imports, var_data.imports)
-            hooks.update(var_data.hooks)
-            interpolations += var_data.interpolations
+            hooks.update(
+                var_data.hooks
+                if isinstance(var_data.hooks, dict)
+                else {k: None for k in var_data.hooks}
+            )
+            interpolations += (
+                var_data.interpolations if isinstance(var_data, VarData) else []
+            )
 
 
         return (
         return (
             cls(
             cls(
@@ -231,6 +238,173 @@ class VarData(Base):
         }
         }
 
 
 
 
+@dataclasses.dataclass(
+    eq=True,
+    frozen=True,
+)
+class ImmutableVarData:
+    """Metadata associated with a Var."""
+
+    # The name of the enclosing state.
+    state: str = dataclasses.field(default="")
+
+    # Imports needed to render this var
+    imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple)
+
+    # Hooks that need to be present in the component to render this var
+    hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
+
+    def __init__(
+        self,
+        state: str = "",
+        imports: ImportDict | ParsedImportDict | None = None,
+        hooks: dict[str, None] | None = None,
+    ):
+        """Initialize the var data.
+
+        Args:
+            state: The name of the enclosing state.
+            imports: Imports needed to render this var.
+            hooks: Hooks that need to be present in the component to render this var.
+        """
+        immutable_imports: ImmutableParsedImportDict = tuple(
+            sorted(
+                ((k, tuple(sorted(v))) for k, v in parse_imports(imports or {}).items())
+            )
+        )
+        object.__setattr__(self, "state", state)
+        object.__setattr__(self, "imports", immutable_imports)
+        object.__setattr__(self, "hooks", tuple(hooks or {}))
+
+    @classmethod
+    def merge(
+        cls, *others: ImmutableVarData | VarData | None
+    ) -> ImmutableVarData | None:
+        """Merge multiple var data objects.
+
+        Args:
+            *others: The var data objects to merge.
+
+        Returns:
+            The merged var data object.
+        """
+        state = ""
+        _imports = {}
+        hooks = {}
+        for var_data in others:
+            if var_data is None:
+                continue
+            state = state or var_data.state
+            _imports = imports.merge_imports(_imports, var_data.imports)
+            hooks.update(
+                var_data.hooks
+                if isinstance(var_data.hooks, dict)
+                else {k: None for k in var_data.hooks}
+            )
+
+        return (
+            ImmutableVarData(
+                state=state,
+                imports=_imports,
+                hooks=hooks,
+            )
+            or None
+        )
+
+    def __bool__(self) -> bool:
+        """Check if the var data is non-empty.
+
+        Returns:
+            True if any field is set to a non-default value.
+        """
+        return bool(self.state or self.imports or self.hooks)
+
+    def __eq__(self, other: Any) -> bool:
+        """Check if two var data objects are equal.
+
+        Args:
+            other: The other var data object to compare.
+
+        Returns:
+            True if all fields are equal and collapsed imports are equal.
+        """
+        if not isinstance(other, (ImmutableVarData, VarData)):
+            return False
+
+        # Don't compare interpolations - that's added in by the decoder, and
+        # not part of the vardata itself.
+        return (
+            self.state == other.state
+            and self.hooks
+            == (
+                other.hooks
+                if isinstance(other, ImmutableVarData)
+                else tuple(other.hooks.keys())
+            )
+            and imports.collapse_imports(self.imports)
+            == imports.collapse_imports(other.imports)
+        )
+
+
+def _decode_var_immutable(value: str) -> tuple[ImmutableVarData | None, str]:
+    """Decode the state name from a formatted var.
+
+    Args:
+        value: The value to extract the state name from.
+
+    Returns:
+        The extracted state name and the value without the state name.
+    """
+    var_datas = []
+    if isinstance(value, str):
+        # fast path if there is no encoded VarData
+        if constants.REFLEX_VAR_OPENING_TAG not in value:
+            return None, value
+
+        offset = 0
+
+        # Initialize some methods for reading json.
+        var_data_config = VarData().__config__
+
+        def json_loads(s):
+            try:
+                return var_data_config.json_loads(s)
+            except json.decoder.JSONDecodeError:
+                return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"'))
+
+        # Find all tags.
+        while m := _decode_var_pattern.search(value):
+            start, end = m.span()
+            value = value[:start] + value[end:]
+
+            serialized_data = m.group(1)
+
+            if serialized_data[1:].isnumeric():
+                # This is a global immutable var.
+                var = _global_vars[int(serialized_data)]
+                var_data = var._var_data
+
+                if var_data is not None:
+                    realstart = start + offset
+
+                    var_datas.append(var_data)
+            else:
+                # Read the JSON, pull out the string length, parse the rest as VarData.
+                data = json_loads(serialized_data)
+                string_length = data.pop("string_length", None)
+                var_data = VarData.parse_obj(data)
+
+                # Use string length to compute positions of interpolations.
+                if string_length is not None:
+                    realstart = start + offset
+                    var_data.interpolations = [(realstart, realstart + string_length)]
+
+                var_datas.append(var_data)
+            offset += end - start
+
+    return ImmutableVarData.merge(*var_datas) if var_datas else None, value
+
+
 def _encode_var(value: Var) -> str:
 def _encode_var(value: Var) -> str:
     """Encode the state name into a formatted var.
     """Encode the state name into a formatted var.
 
 
@@ -306,9 +480,6 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
 
 
                 if var_data is not None:
                 if var_data is not None:
                     realstart = start + offset
                     realstart = start + offset
-                    var_data.interpolations = [
-                        (realstart, realstart + len(var._var_name))
-                    ]
 
 
                     var_datas.append(var_data)
                     var_datas.append(var_data)
             else:
             else:
@@ -1814,6 +1985,14 @@ class Var:
         """
         """
         return self._var_data.state if self._var_data else ""
         return self._var_data.state if self._var_data else ""
 
 
+    def _get_all_var_data(self) -> VarData | None:
+        """Get all the var data.
+
+        Returns:
+            The var data.
+        """
+        return self._var_data
+
     @property
     @property
     def _var_name_unwrapped(self) -> str:
     def _var_name_unwrapped(self) -> str:
         """Get the var str without wrapping in curly braces.
         """Get the var str without wrapping in curly braces.

+ 20 - 2
reflex/vars.pyi

@@ -29,7 +29,7 @@ from reflex.state import State as State
 from reflex.utils import console as console
 from reflex.utils import console as console
 from reflex.utils import format as format
 from reflex.utils import format as format
 from reflex.utils import types as types
 from reflex.utils import types as types
-from reflex.utils.imports import ImportDict, ParsedImportDict
+from reflex.utils.imports import ImmutableParsedImportDict, ImportDict, ParsedImportDict
 
 
 USED_VARIABLES: Incomplete
 USED_VARIABLES: Incomplete
 
 
@@ -47,7 +47,24 @@ class VarData(Base):
     hooks: Dict[str, None] = {}
     hooks: Dict[str, None] = {}
     interpolations: List[Tuple[int, int]] = []
     interpolations: List[Tuple[int, int]] = []
     @classmethod
     @classmethod
-    def merge(cls, *others: VarData | None) -> VarData | None: ...
+    def merge(cls, *others: ImmutableVarData | VarData | None) -> VarData | None: ...
+
+class ImmutableVarData:
+    state: str = ""
+    imports: ImmutableParsedImportDict = tuple()
+    hooks: Tuple[str, ...] = tuple()
+    def __init__(
+        self,
+        state: str = "",
+        imports: ImportDict | ParsedImportDict | None = None,
+        hooks: dict[str, None] | None = None,
+    ) -> None: ...
+    @classmethod
+    def merge(
+        cls, *others: ImmutableVarData | VarData | None
+    ) -> ImmutableVarData | None: ...
+
+def _decode_var_immutable(value: str) -> tuple[ImmutableVarData, str]: ...
 
 
 class Var:
 class Var:
     _var_name: str
     _var_name: str
@@ -133,6 +150,7 @@ class Var:
     @property
     @property
     def _var_full_name(self) -> str: ...
     def _var_full_name(self) -> str: ...
     def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
     def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
+    def _get_all_var_data(self) -> VarData: ...
 
 
 @dataclass(eq=False)
 @dataclass(eq=False)
 class BaseVar(Var):
 class BaseVar(Var):

+ 58 - 5
tests/test_var.py

@@ -7,12 +7,17 @@ from pandas import DataFrame
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
 from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
-from reflex.experimental.vars.base import ImmutableVar
+from reflex.experimental.vars.base import (
+    ConcatVarOperation,
+    ImmutableVar,
+    LiteralStringVar,
+)
 from reflex.state import BaseState
 from reflex.state import BaseState
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
 from reflex.vars import (
 from reflex.vars import (
     BaseVar,
     BaseVar,
     ComputedVar,
     ComputedVar,
+    ImmutableVarData,
     Var,
     Var,
     VarData,
     VarData,
     computed_var,
     computed_var,
@@ -880,13 +885,61 @@ def test_retrival():
     )
     )
     assert (
     assert (
         result_var_data.imports
         result_var_data.imports
-        == result_immutable_var_data.imports
+        == (
+            result_immutable_var_data.imports
+            if isinstance(result_immutable_var_data.imports, dict)
+            else {
+                k: list(v)
+                for k, v in result_immutable_var_data.imports
+                if k in original_var_data.imports
+            }
+        )
         == original_var_data.imports
         == original_var_data.imports
     )
     )
     assert (
     assert (
-        result_var_data.hooks
-        == result_immutable_var_data.hooks
-        == original_var_data.hooks
+        list(result_var_data.hooks.keys())
+        == (
+            list(result_immutable_var_data.hooks.keys())
+            if isinstance(result_immutable_var_data.hooks, dict)
+            else list(result_immutable_var_data.hooks)
+        )
+        == list(original_var_data.hooks.keys())
+    )
+
+
+def test_fstring_concat():
+    original_var_with_data = Var.create_safe(
+        "imagination", _var_data=VarData(state="fear")
+    )
+
+    immutable_var_with_data = ImmutableVar.create_safe(
+        "consequences",
+        _var_data=VarData(
+            imports={
+                "react": [ImportVar(tag="useRef")],
+                "utils": [ImportVar(tag="useEffect")],
+            }
+        ),
+    )
+
+    f_string = f"foo{original_var_with_data}bar{immutable_var_with_data}baz"
+
+    string_concat = LiteralStringVar.create(
+        f_string,
+        _var_data=VarData(
+            hooks={"const state = useContext(StateContexts.state)": None}
+        ),
+    )
+
+    assert str(string_concat) == '"foo"+imagination+"bar"+consequences+"baz"'
+    assert isinstance(string_concat, ConcatVarOperation)
+    assert string_concat._get_all_var_data() == ImmutableVarData(
+        state="fear",
+        imports={
+            "react": [ImportVar(tag="useRef")],
+            "utils": [ImportVar(tag="useEffect")],
+        },
+        hooks={"const state = useContext(StateContexts.state)": None},
     )
     )