소스 검색

[REF-3225] implement __format__ for immutable vars (#3617)

* implement format for immutable vars

* add some basic test

* make reference only after formatting

* win over pyright

* hopefully now pyright doesn't hate me

* forgot some _var_data

* i don't know how imports work

* use f_string var and remove assignments from pyi file

* override post_init to not break immutability

* add create_safe and test for it
Khaleel Al-Adhami 10 달 전
부모
커밋
d4d077818c
4개의 변경된 파일124개의 추가작업 그리고 10개의 파일을 삭제
  1. 53 1
      reflex/experimental/vars/base.py
  2. 27 9
      reflex/vars.py
  3. 3 0
      reflex/vars.pyi
  4. 41 0
      tests/test_var.py

+ 53 - 1
reflex/experimental/vars/base.py

@@ -6,9 +6,10 @@ import dataclasses
 import sys
 import sys
 from typing import Any, Optional, Type
 from typing import Any, Optional, Type
 
 
+from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
 from reflex.utils import serializers, types
 from reflex.utils import serializers, types
 from reflex.utils.exceptions import VarTypeError
 from reflex.utils.exceptions import VarTypeError
-from reflex.vars import Var, VarData, _extract_var_data
+from reflex.vars import Var, VarData, _decode_var, _extract_var_data, _global_vars
 
 
 
 
 @dataclasses.dataclass(
 @dataclasses.dataclass(
@@ -55,6 +56,15 @@ class ImmutableVar(Var):
         """
         """
         return False
         return False
 
 
+    def __post_init__(self):
+        """Post-initialize the var."""
+        # Decode any inline Var markup and apply it to the instance
+        _var_data, _var_name = _decode_var(self._var_name)
+        if _var_data:
+            self.__init__(
+                _var_name, self._var_type, VarData.merge(self._var_data, _var_data)
+            )
+
     def _replace(self, merge_var_data=None, **kwargs: Any):
     def _replace(self, merge_var_data=None, **kwargs: Any):
         """Make a copy of this Var with updated fields.
         """Make a copy of this Var with updated fields.
 
 
@@ -156,3 +166,45 @@ class ImmutableVar(Var):
             _var_type=type_,
             _var_type=type_,
             _var_data=_var_data,
             _var_data=_var_data,
         )
         )
+
+    @classmethod
+    def create_safe(
+        cls,
+        value: Any,
+        _var_is_local: bool | None = None,
+        _var_is_string: bool | None = None,
+        _var_data: VarData | None = None,
+    ) -> Var:
+        """Create a var from a value, asserting that it is not None.
+
+        Args:
+            value: The value to create the var from.
+            _var_is_local: Whether the var is local. Deprecated.
+            _var_is_string: Whether the var is a string literal. Deprecated.
+            _var_data: Additional hooks and imports associated with the Var.
+
+        Returns:
+            The var.
+        """
+        var = cls.create(
+            value,
+            _var_is_local=_var_is_local,
+            _var_is_string=_var_is_string,
+            _var_data=_var_data,
+        )
+        assert var is not None
+        return var
+
+    def __format__(self, format_spec: str) -> str:
+        """Format the var into a Javascript equivalent to an f-string.
+
+        Args:
+            format_spec: The format specifier (Ignored for now).
+
+        Returns:
+            The formatted var.
+        """
+        _global_vars[hash(self)] = self
+
+        # Encode the _var_data into the formatted output for tracking purposes.
+        return f"{REFLEX_VAR_OPENING_TAG}{hash(self)}{REFLEX_VAR_CLOSING_TAG}{self._var_name}"

+ 27 - 9
reflex/vars.py

@@ -262,6 +262,9 @@ _decode_var_pattern_re = (
 )
 )
 _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
 _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
 
 
+# Defined global immutable vars.
+_global_vars: Dict[int, Var] = {}
+
 
 
 def _decode_var(value: str) -> tuple[VarData | None, str]:
 def _decode_var(value: str) -> tuple[VarData | None, str]:
     """Decode the state name from a formatted var.
     """Decode the state name from a formatted var.
@@ -294,17 +297,32 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
             start, end = m.span()
             start, end = m.span()
             value = value[:start] + value[end:]
             value = value[:start] + value[end:]
 
 
-            # Read the JSON, pull out the string length, parse the rest as VarData.
-            data = json_loads(m.group(1))
-            string_length = data.pop("string_length", None)
-            var_data = VarData.parse_obj(data)
+            serialized_data = m.group(1)
+
+            if serialized_data[1:].isnumeric():
+                # This is a global immutable var.
+                var = _global_vars[int(serialized_data)]
+                var_data = var._var_data
+
+                if var_data is not None:
+                    realstart = start + offset
+                    var_data.interpolations = [
+                        (realstart, realstart + len(var._var_name))
+                    ]
+
+                    var_datas.append(var_data)
+            else:
+                # Read the JSON, pull out the string length, parse the rest as VarData.
+                data = json_loads(serialized_data)
+                string_length = data.pop("string_length", None)
+                var_data = VarData.parse_obj(data)
 
 
-            # Use string length to compute positions of interpolations.
-            if string_length is not None:
-                realstart = start + offset
-                var_data.interpolations = [(realstart, realstart + string_length)]
+                # Use string length to compute positions of interpolations.
+                if string_length is not None:
+                    realstart = start + offset
+                    var_data.interpolations = [(realstart, realstart + string_length)]
 
 
-            var_datas.append(var_data)
+                var_datas.append(var_data)
             offset += end - start
             offset += end - start
 
 
     return VarData.merge(*var_datas) if var_datas else None, value
     return VarData.merge(*var_datas) if var_datas else None, value

+ 3 - 0
reflex/vars.pyi

@@ -35,6 +35,9 @@ USED_VARIABLES: Incomplete
 
 
 def get_unique_variable_name() -> str: ...
 def get_unique_variable_name() -> str: ...
 def _encode_var(value: Var) -> str: ...
 def _encode_var(value: Var) -> str: ...
+
+_global_vars: Dict[int, Var]
+
 def _decode_var(value: str) -> tuple[VarData, str]: ...
 def _decode_var(value: str) -> tuple[VarData, str]: ...
 def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
 def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
 
 

+ 41 - 0
tests/test_var.py

@@ -6,11 +6,15 @@ import pytest
 from pandas import DataFrame
 from pandas import DataFrame
 
 
 from reflex.base import Base
 from reflex.base import Base
+from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
+from reflex.experimental.vars.base import ImmutableVar
 from reflex.state import BaseState
 from reflex.state import BaseState
+from reflex.utils.imports import ImportVar
 from reflex.vars import (
 from reflex.vars import (
     BaseVar,
     BaseVar,
     ComputedVar,
     ComputedVar,
     Var,
     Var,
+    VarData,
     computed_var,
     computed_var,
 )
 )
 
 
@@ -849,6 +853,43 @@ def test_state_with_initial_computed_var(
         assert runtime_dict[var_name] == expected_runtime
         assert runtime_dict[var_name] == expected_runtime
 
 
 
 
+def test_retrival():
+    var_without_data = ImmutableVar.create("test")
+    assert var_without_data is not None
+
+    original_var_data = VarData(
+        state="Test",
+        imports={"react": [ImportVar(tag="useRef")]},
+        hooks={"const state = useContext(StateContexts.state)": None},
+    )
+
+    var_with_data = var_without_data._replace(merge_var_data=original_var_data)
+
+    f_string = f"foo{var_with_data}bar"
+
+    assert REFLEX_VAR_OPENING_TAG in f_string
+    assert REFLEX_VAR_CLOSING_TAG in f_string
+
+    result_var_data = Var.create_safe(f_string)._var_data
+    result_immutable_var_data = ImmutableVar.create_safe(f_string)._var_data
+    assert result_var_data is not None and result_immutable_var_data is not None
+    assert (
+        result_var_data.state
+        == result_immutable_var_data.state
+        == original_var_data.state
+    )
+    assert (
+        result_var_data.imports
+        == result_immutable_var_data.imports
+        == original_var_data.imports
+    )
+    assert (
+        result_var_data.hooks
+        == result_immutable_var_data.hooks
+        == original_var_data.hooks
+    )
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "out, expected",
     "out, expected",
     [
     [