Pārlūkot izejas kodu

Fix fstrings being escaped improperly (#2571)

invrainbow 1 gadu atpakaļ
vecāks
revīzija
e729a315f8

+ 56 - 7
reflex/utils/format.py

@@ -188,26 +188,73 @@ def to_kebab_case(text: str) -> str:
     return to_snake_case(text).replace("_", "-")
     return to_snake_case(text).replace("_", "-")
 
 
 
 
-def format_string(string: str) -> str:
-    """Format the given string as a JS string literal..
+def _escape_js_string(string: str) -> str:
+    """Escape the string for use as a JS string literal.
 
 
     Args:
     Args:
-        string: The string to format.
+        string: The string to escape.
 
 
     Returns:
     Returns:
-        The formatted string.
+        The escaped string.
     """
     """
     # Escape backticks.
     # Escape backticks.
     string = string.replace(r"\`", "`")
     string = string.replace(r"\`", "`")
     string = string.replace("`", r"\`")
     string = string.replace("`", r"\`")
+    return string
+
+
+def _wrap_js_string(string: str) -> str:
+    """Wrap string so it looks like {`string`}.
+
+    Args:
+        string: The string to wrap.
 
 
-    # Wrap the string so it looks like {`string`}.
+    Returns:
+        The wrapped string.
+    """
     string = wrap(string, "`")
     string = wrap(string, "`")
     string = wrap(string, "{")
     string = wrap(string, "{")
-
     return string
     return string
 
 
 
 
+def format_string(string: str) -> str:
+    """Format the given string as a JS string literal..
+
+    Args:
+        string: The string to format.
+
+    Returns:
+        The formatted string.
+    """
+    return _wrap_js_string(_escape_js_string(string))
+
+
+def format_f_string_prop(prop: BaseVar) -> str:
+    """Format the string in a given prop as an f-string.
+
+    Args:
+        prop: The prop to format.
+
+    Returns:
+        The formatted string.
+    """
+    s = prop._var_full_name
+    var_data = prop._var_data
+    interps = var_data.interpolations if var_data else []
+    parts: List[str] = []
+
+    if interps:
+        for i, (start, end) in enumerate(interps):
+            prev_end = interps[i - 1][1] if i > 0 else 0
+            parts.append(_escape_js_string(s[prev_end:start]))
+            parts.append(s[start:end])
+        parts.append(_escape_js_string(s[interps[-1][1] :]))
+    else:
+        parts.append(_escape_js_string(s))
+
+    return _wrap_js_string("".join(parts))
+
+
 def format_var(var: Var) -> str:
 def format_var(var: Var) -> str:
     """Format the given Var as a javascript value.
     """Format the given Var as a javascript value.
 
 
@@ -345,7 +392,9 @@ def format_prop(
         if isinstance(prop, Var):
         if isinstance(prop, Var):
             if not prop._var_is_local or prop._var_is_string:
             if not prop._var_is_local or prop._var_is_string:
                 return str(prop)
                 return str(prop)
-            if types._issubclass(prop._var_type, str):
+            if isinstance(prop, BaseVar) and types._issubclass(prop._var_type, str):
+                if prop._var_data and prop._var_data.interpolations:
+                    return format_f_string_prop(prop)
                 return format_string(prop._var_full_name)
                 return format_string(prop._var_full_name)
             prop = prop._var_full_name
             prop = prop._var_full_name
 
 

+ 57 - 19
reflex/vars.py

@@ -31,8 +31,6 @@ from typing import (
     get_type_hints,
     get_type_hints,
 )
 )
 
 
-import pydantic
-
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
 from reflex.utils import console, format, imports, serializers, types
 from reflex.utils import console, format, imports, serializers, types
@@ -122,6 +120,11 @@ class VarData(Base):
     # Hooks that need to be present in the component to render this var
     # Hooks that need to be present in the component to render this var
     hooks: Set[str] = set()
     hooks: Set[str] = set()
 
 
+    # Positions of interpolated strings. This is used by the decoder to figure
+    # out where the interpolations are and only escape the non-interpolated
+    # segments.
+    interpolations: List[Tuple[int, int]] = []
+
     @classmethod
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None:
     def merge(cls, *others: VarData | None) -> VarData | None:
         """Merge multiple var data objects.
         """Merge multiple var data objects.
@@ -135,17 +138,21 @@ class VarData(Base):
         state = ""
         state = ""
         _imports = {}
         _imports = {}
         hooks = set()
         hooks = set()
+        interpolations = []
         for var_data in others:
         for var_data in others:
             if var_data is None:
             if var_data is None:
                 continue
                 continue
             state = state or var_data.state
             state = state or var_data.state
             _imports = imports.merge_imports(_imports, var_data.imports)
             _imports = imports.merge_imports(_imports, var_data.imports)
             hooks.update(var_data.hooks)
             hooks.update(var_data.hooks)
+            interpolations += var_data.interpolations
+
         return (
         return (
             cls(
             cls(
                 state=state,
                 state=state,
                 imports=_imports,
                 imports=_imports,
                 hooks=hooks,
                 hooks=hooks,
+                interpolations=interpolations,
             )
             )
             or None
             or None
         )
         )
@@ -156,7 +163,7 @@ class VarData(Base):
         Returns:
         Returns:
             True if any field is set to a non-default value.
             True if any field is set to a non-default value.
         """
         """
-        return bool(self.state or self.imports or self.hooks)
+        return bool(self.state or self.imports or self.hooks or self.interpolations)
 
 
     def __eq__(self, other: Any) -> bool:
     def __eq__(self, other: Any) -> bool:
         """Check if two var data objects are equal.
         """Check if two var data objects are equal.
@@ -169,6 +176,9 @@ class VarData(Base):
         """
         """
         if not isinstance(other, VarData):
         if not isinstance(other, VarData):
             return False
             return False
+
+        # Don't compare interpolations - that's added in by the decoder, and
+        # not part of the vardata itself.
         return (
         return (
             self.state == other.state
             self.state == other.state
             and self.hooks == other.hooks
             and self.hooks == other.hooks
@@ -184,6 +194,7 @@ class VarData(Base):
         """
         """
         return {
         return {
             "state": self.state,
             "state": self.state,
+            "interpolations": list(self.interpolations),
             "imports": {
             "imports": {
                 lib: [import_var.dict() for import_var in import_vars]
                 lib: [import_var.dict() for import_var in import_vars]
                 for lib, import_vars in self.imports.items()
                 for lib, import_vars in self.imports.items()
@@ -202,10 +213,18 @@ def _encode_var(value: Var) -> str:
         The encoded var.
         The encoded var.
     """
     """
     if value._var_data:
     if value._var_data:
+        from reflex.utils.serializers import serialize
+
+        final_value = str(value)
+        data = value._var_data.dict()
+        data["string_length"] = len(final_value)
+        data_json = value._var_data.__config__.json_dumps(data, default=serialize)
+
         return (
         return (
-            f"{constants.REFLEX_VAR_OPENING_TAG}{value._var_data.json()}{constants.REFLEX_VAR_CLOSING_TAG}"
-            + str(value)
+            f"{constants.REFLEX_VAR_OPENING_TAG}{data_json}{constants.REFLEX_VAR_CLOSING_TAG}"
+            + final_value
         )
         )
+
     return str(value)
     return str(value)
 
 
 
 
@@ -220,21 +239,40 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
     """
     """
     var_datas = []
     var_datas = []
     if isinstance(value, str):
     if isinstance(value, str):
-        # Extract the state name from a formatted var
-        while m := re.match(
-            pattern=rf"(.*){constants.REFLEX_VAR_OPENING_TAG}(.*){constants.REFLEX_VAR_CLOSING_TAG}(.*)",
-            string=value,
-            flags=re.DOTALL,  # Ensure . matches newline characters.
-        ):
-            value = m.group(1) + m.group(3)
+        offset = 0
+
+        # Initialize some methods for reading json.
+        var_data_config = VarData().__config__
+
+        def json_loads(s):
             try:
             try:
-                var_datas.append(VarData.parse_raw(m.group(2)))
-            except pydantic.ValidationError:
-                # If the VarData is invalid, it was probably json-encoded twice...
-                var_datas.append(VarData.parse_raw(json.loads(f'"{m.group(2)}"')))
-    if var_datas:
-        return VarData.merge(*var_datas), value
-    return None, value
+                return var_data_config.json_loads(s)
+            except json.decoder.JSONDecodeError:
+                return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"'))
+
+        # Compile regex for finding reflex var tags.
+        pattern_re = rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
+        pattern = re.compile(pattern_re, flags=re.DOTALL)
+
+        # Find all tags.
+        while m := pattern.search(value):
+            start, end = m.span()
+            value = value[:start] + value[end:]
+
+            # Read the JSON, pull out the string length, parse the rest as VarData.
+            data = json_loads(m.group(1))
+            string_length = data.pop("string_length", None)
+            var_data = VarData.parse_obj(data)
+
+            # Use string length to compute positions of interpolations.
+            if string_length is not None:
+                realstart = start + offset
+                var_data.interpolations = [(realstart, realstart + string_length)]
+
+            var_datas.append(var_data)
+            offset += end - start
+
+    return VarData.merge(*var_datas) if var_datas else None, value
 
 
 
 
 def _extract_var_data(value: Iterable) -> list[VarData | None]:
 def _extract_var_data(value: Iterable) -> list[VarData | None]:

+ 3 - 0
reflex/vars.pyi

@@ -1,4 +1,5 @@
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
+from __future__ import annotations
 
 
 from dataclasses import dataclass
 from dataclasses import dataclass
 from _typeshed import Incomplete
 from _typeshed import Incomplete
@@ -17,6 +18,7 @@ from typing import (
     List,
     List,
     Optional,
     Optional,
     Set,
     Set,
+    Tuple,
     Type,
     Type,
     Union,
     Union,
     overload,
     overload,
@@ -34,6 +36,7 @@ class VarData(Base):
     state: str
     state: str
     imports: dict[str, set[ImportVar]]
     imports: dict[str, set[ImportVar]]
     hooks: set[str]
     hooks: set[str]
+    interpolations: List[Tuple[int, int]]
     @classmethod
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None: ...
     def merge(cls, *others: VarData | None) -> VarData | None: ...
 
 

+ 6 - 0
tests/components/layout/test_cond.py

@@ -27,6 +27,12 @@ def cond_state(request):
     return CondState
     return CondState
 
 
 
 
+def test_f_string_cond_interpolation():
+    # make sure backticks inside interpolation don't get escaped
+    var = Var.create(f"x {cond(True, 'a', 'b')}")
+    assert str(var) == "x ${isTrue(true) ? `a` : `b`}"
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "cond_state",
     "cond_state",
     [
     [

+ 4 - 1
tests/components/test_component.py

@@ -687,7 +687,10 @@ def test_stateful_banner():
 
 
 TEST_VAR = Var.create_safe("test")._replace(
 TEST_VAR = Var.create_safe("test")._replace(
     merge_var_data=VarData(
     merge_var_data=VarData(
-        hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
+        hooks={"useTest"},
+        imports={"test": {ImportVar(tag="test")}},
+        state="Test",
+        interpolations=[],
     )
     )
 )
 )
 FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
 FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")

+ 1 - 1
tests/test_var.py

@@ -718,7 +718,7 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (
         (
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
-            'testing f-string with $<reflex.Var>{"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}</reflex.Var>{state.myvar}',
+            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}</reflex.Var>{state.myvar}',
         ),
         ),
         (
         (
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",