1
0
Эх сурвалжийг харах

Fix fstrings being escaped improperly (#2571)

invrainbow 1 жил өмнө
parent
commit
e729a315f8

+ 56 - 7
reflex/utils/format.py

@@ -188,26 +188,73 @@ def to_kebab_case(text: str) -> str:
     return to_snake_case(text).replace("_", "-")
 
 
-def format_string(string: str) -> str:
-    """Format the given string as a JS string literal..
+def _escape_js_string(string: str) -> str:
+    """Escape the string for use as a JS string literal.
 
     Args:
-        string: The string to format.
+        string: The string to escape.
 
     Returns:
-        The formatted string.
+        The escaped string.
     """
     # Escape backticks.
     string = string.replace(r"\`", "`")
     string = string.replace("`", r"\`")
+    return string
+
+
+def _wrap_js_string(string: str) -> str:
+    """Wrap string so it looks like {`string`}.
+
+    Args:
+        string: The string to wrap.
 
-    # Wrap the string so it looks like {`string`}.
+    Returns:
+        The wrapped string.
+    """
     string = wrap(string, "`")
     string = wrap(string, "{")
-
     return string
 
 
+def format_string(string: str) -> str:
+    """Format the given string as a JS string literal..
+
+    Args:
+        string: The string to format.
+
+    Returns:
+        The formatted string.
+    """
+    return _wrap_js_string(_escape_js_string(string))
+
+
+def format_f_string_prop(prop: BaseVar) -> str:
+    """Format the string in a given prop as an f-string.
+
+    Args:
+        prop: The prop to format.
+
+    Returns:
+        The formatted string.
+    """
+    s = prop._var_full_name
+    var_data = prop._var_data
+    interps = var_data.interpolations if var_data else []
+    parts: List[str] = []
+
+    if interps:
+        for i, (start, end) in enumerate(interps):
+            prev_end = interps[i - 1][1] if i > 0 else 0
+            parts.append(_escape_js_string(s[prev_end:start]))
+            parts.append(s[start:end])
+        parts.append(_escape_js_string(s[interps[-1][1] :]))
+    else:
+        parts.append(_escape_js_string(s))
+
+    return _wrap_js_string("".join(parts))
+
+
 def format_var(var: Var) -> str:
     """Format the given Var as a javascript value.
 
@@ -345,7 +392,9 @@ def format_prop(
         if isinstance(prop, Var):
             if not prop._var_is_local or prop._var_is_string:
                 return str(prop)
-            if types._issubclass(prop._var_type, str):
+            if isinstance(prop, BaseVar) and types._issubclass(prop._var_type, str):
+                if prop._var_data and prop._var_data.interpolations:
+                    return format_f_string_prop(prop)
                 return format_string(prop._var_full_name)
             prop = prop._var_full_name
 

+ 57 - 19
reflex/vars.py

@@ -31,8 +31,6 @@ from typing import (
     get_type_hints,
 )
 
-import pydantic
-
 from reflex import constants
 from reflex.base import Base
 from reflex.utils import console, format, imports, serializers, types
@@ -122,6 +120,11 @@ class VarData(Base):
     # Hooks that need to be present in the component to render this var
     hooks: Set[str] = set()
 
+    # Positions of interpolated strings. This is used by the decoder to figure
+    # out where the interpolations are and only escape the non-interpolated
+    # segments.
+    interpolations: List[Tuple[int, int]] = []
+
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None:
         """Merge multiple var data objects.
@@ -135,17 +138,21 @@ class VarData(Base):
         state = ""
         _imports = {}
         hooks = set()
+        interpolations = []
         for var_data in others:
             if var_data is None:
                 continue
             state = state or var_data.state
             _imports = imports.merge_imports(_imports, var_data.imports)
             hooks.update(var_data.hooks)
+            interpolations += var_data.interpolations
+
         return (
             cls(
                 state=state,
                 imports=_imports,
                 hooks=hooks,
+                interpolations=interpolations,
             )
             or None
         )
@@ -156,7 +163,7 @@ class VarData(Base):
         Returns:
             True if any field is set to a non-default value.
         """
-        return bool(self.state or self.imports or self.hooks)
+        return bool(self.state or self.imports or self.hooks or self.interpolations)
 
     def __eq__(self, other: Any) -> bool:
         """Check if two var data objects are equal.
@@ -169,6 +176,9 @@ class VarData(Base):
         """
         if not isinstance(other, VarData):
             return False
+
+        # Don't compare interpolations - that's added in by the decoder, and
+        # not part of the vardata itself.
         return (
             self.state == other.state
             and self.hooks == other.hooks
@@ -184,6 +194,7 @@ class VarData(Base):
         """
         return {
             "state": self.state,
+            "interpolations": list(self.interpolations),
             "imports": {
                 lib: [import_var.dict() for import_var in import_vars]
                 for lib, import_vars in self.imports.items()
@@ -202,10 +213,18 @@ def _encode_var(value: Var) -> str:
         The encoded var.
     """
     if value._var_data:
+        from reflex.utils.serializers import serialize
+
+        final_value = str(value)
+        data = value._var_data.dict()
+        data["string_length"] = len(final_value)
+        data_json = value._var_data.__config__.json_dumps(data, default=serialize)
+
         return (
-            f"{constants.REFLEX_VAR_OPENING_TAG}{value._var_data.json()}{constants.REFLEX_VAR_CLOSING_TAG}"
-            + str(value)
+            f"{constants.REFLEX_VAR_OPENING_TAG}{data_json}{constants.REFLEX_VAR_CLOSING_TAG}"
+            + final_value
         )
+
     return str(value)
 
 
@@ -220,21 +239,40 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
     """
     var_datas = []
     if isinstance(value, str):
-        # Extract the state name from a formatted var
-        while m := re.match(
-            pattern=rf"(.*){constants.REFLEX_VAR_OPENING_TAG}(.*){constants.REFLEX_VAR_CLOSING_TAG}(.*)",
-            string=value,
-            flags=re.DOTALL,  # Ensure . matches newline characters.
-        ):
-            value = m.group(1) + m.group(3)
+        offset = 0
+
+        # Initialize some methods for reading json.
+        var_data_config = VarData().__config__
+
+        def json_loads(s):
             try:
-                var_datas.append(VarData.parse_raw(m.group(2)))
-            except pydantic.ValidationError:
-                # If the VarData is invalid, it was probably json-encoded twice...
-                var_datas.append(VarData.parse_raw(json.loads(f'"{m.group(2)}"')))
-    if var_datas:
-        return VarData.merge(*var_datas), value
-    return None, value
+                return var_data_config.json_loads(s)
+            except json.decoder.JSONDecodeError:
+                return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"'))
+
+        # Compile regex for finding reflex var tags.
+        pattern_re = rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
+        pattern = re.compile(pattern_re, flags=re.DOTALL)
+
+        # Find all tags.
+        while m := pattern.search(value):
+            start, end = m.span()
+            value = value[:start] + value[end:]
+
+            # Read the JSON, pull out the string length, parse the rest as VarData.
+            data = json_loads(m.group(1))
+            string_length = data.pop("string_length", None)
+            var_data = VarData.parse_obj(data)
+
+            # Use string length to compute positions of interpolations.
+            if string_length is not None:
+                realstart = start + offset
+                var_data.interpolations = [(realstart, realstart + string_length)]
+
+            var_datas.append(var_data)
+            offset += end - start
+
+    return VarData.merge(*var_datas) if var_datas else None, value
 
 
 def _extract_var_data(value: Iterable) -> list[VarData | None]:

+ 3 - 0
reflex/vars.pyi

@@ -1,4 +1,5 @@
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
+from __future__ import annotations
 
 from dataclasses import dataclass
 from _typeshed import Incomplete
@@ -17,6 +18,7 @@ from typing import (
     List,
     Optional,
     Set,
+    Tuple,
     Type,
     Union,
     overload,
@@ -34,6 +36,7 @@ class VarData(Base):
     state: str
     imports: dict[str, set[ImportVar]]
     hooks: set[str]
+    interpolations: List[Tuple[int, int]]
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None: ...
 

+ 6 - 0
tests/components/layout/test_cond.py

@@ -27,6 +27,12 @@ def cond_state(request):
     return CondState
 
 
+def test_f_string_cond_interpolation():
+    # make sure backticks inside interpolation don't get escaped
+    var = Var.create(f"x {cond(True, 'a', 'b')}")
+    assert str(var) == "x ${isTrue(true) ? `a` : `b`}"
+
+
 @pytest.mark.parametrize(
     "cond_state",
     [

+ 4 - 1
tests/components/test_component.py

@@ -687,7 +687,10 @@ def test_stateful_banner():
 
 TEST_VAR = Var.create_safe("test")._replace(
     merge_var_data=VarData(
-        hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
+        hooks={"useTest"},
+        imports={"test": {ImportVar(tag="test")}},
+        state="Test",
+        interpolations=[],
     )
 )
 FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")

+ 1 - 1
tests/test_var.py

@@ -718,7 +718,7 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
-            'testing f-string with $<reflex.Var>{"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}</reflex.Var>{state.myvar}',
+            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}</reflex.Var>{state.myvar}',
         ),
         (
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",