Explorar o código

Get `rx.color` working with fstrings (#2562)

* fix for rx.color working with fstrings

* Fix fstrings issues
Elijah Ahianyo hai 1 ano
pai
achega
ccc9c32c95

+ 2 - 3
reflex/components/core/colors.py

@@ -2,11 +2,10 @@
 
 
 from reflex.constants.colors import Color, ColorType, ShadeType
 from reflex.constants.colors import Color, ColorType, ShadeType
 from reflex.utils.types import validate_parameter_literals
 from reflex.utils.types import validate_parameter_literals
-from reflex.vars import Var
 
 
 
 
 @validate_parameter_literals
 @validate_parameter_literals
-def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Var:
+def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Color:
     """Create a color object.
     """Create a color object.
 
 
     Args:
     Args:
@@ -17,4 +16,4 @@ def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Var:
     Returns:
     Returns:
         The color object.
         The color object.
     """
     """
-    return Var.create(Color(color, shade, alpha))._replace(_var_is_string=True)  # type: ignore
+    return Color(color, shade, alpha)

+ 12 - 0
reflex/components/core/cond.py

@@ -7,6 +7,7 @@ from reflex.components.base.fragment import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.components.tags import CondTag, Tag
 from reflex.components.tags import CondTag, Tag
 from reflex.constants import Dirs
 from reflex.constants import Dirs
+from reflex.constants.colors import Color
 from reflex.utils import format, imports
 from reflex.utils import format, imports
 from reflex.vars import BaseVar, Var, VarData
 from reflex.vars import BaseVar, Var, VarData
 
 
@@ -167,6 +168,17 @@ def cond(condition: Any, c1: Any, c2: Any = None):
     if isinstance(c2, Var):
     if isinstance(c2, Var):
         var_datas.append(c2._var_data)
         var_datas.append(c2._var_data)
 
 
+    def create_var(cond_part):
+        return Var.create_safe(
+            cond_part,
+            _var_is_string=type(cond_part) is str or isinstance(cond_part, Color),
+        )
+
+    # convert the truth and false cond parts into vars so the _var_data can be obtained.
+    c1 = create_var(c1)
+    c2 = create_var(c2)
+    var_datas.extend([c1._var_data, c2._var_data])
+
     # Create the conditional var.
     # Create the conditional var.
     return cond_var._replace(
     return cond_var._replace(
         _var_name=format.format_cond(
         _var_name=format.format_cond(

+ 3 - 1
reflex/components/core/match.py

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 
 from reflex.components.base import Fragment
 from reflex.components.base import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
+from reflex.components.core.colors import Color
 from reflex.components.tags import MatchTag, Tag
 from reflex.components.tags import MatchTag, Tag
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format, imports, types
 from reflex.utils import format, imports, types
@@ -116,7 +117,8 @@ class Match(MemoizationLeaf):
         """
         """
         _var_data = case_element._var_data if isinstance(case_element, Style) else None  # type: ignore
         _var_data = case_element._var_data if isinstance(case_element, Style) else None  # type: ignore
         case_element = Var.create(
         case_element = Var.create(
-            case_element, _var_is_string=type(case_element) is str
+            case_element,
+            _var_is_string=type(case_element) is str or isinstance(case_element, Color),
         )
         )
         if _var_data is not None:
         if _var_data is not None:
             case_element._var_data = VarData.merge(case_element._var_data, _var_data)  # type: ignore
             case_element._var_data = VarData.merge(case_element._var_data, _var_data)  # type: ignore

+ 4 - 0
reflex/constants/__init__.py

@@ -6,6 +6,8 @@ from .base import (
     LOCAL_STORAGE,
     LOCAL_STORAGE,
     POLLING_MAX_HTTP_BUFFER_SIZE,
     POLLING_MAX_HTTP_BUFFER_SIZE,
     PYTEST_CURRENT_TEST,
     PYTEST_CURRENT_TEST,
+    REFLEX_VAR_CLOSING_TAG,
+    REFLEX_VAR_OPENING_TAG,
     RELOAD_CONFIG,
     RELOAD_CONFIG,
     SKIP_COMPILE_ENV_VAR,
     SKIP_COMPILE_ENV_VAR,
     ColorMode,
     ColorMode,
@@ -73,6 +75,8 @@ __ALL__ = [
     Expiration,
     Expiration,
     Ext,
     Ext,
     Fnm,
     Fnm,
+    REFLEX_VAR_CLOSING_TAG,
+    REFLEX_VAR_OPENING_TAG,
     GitIgnore,
     GitIgnore,
     Hooks,
     Hooks,
     Imports,
     Imports,

+ 3 - 0
reflex/constants/base.py

@@ -187,3 +187,6 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
 # Testing os env set by pytest when running a test case.
 # Testing os env set by pytest when running a test case.
 PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
 PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
 RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"
 RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"
+
+REFLEX_VAR_OPENING_TAG = "<reflex.Var>"
+REFLEX_VAR_CLOSING_TAG = "</reflex.Var>"

+ 14 - 0
reflex/utils/types.py

@@ -25,6 +25,7 @@ from pydantic.fields import ModelField
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
 from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
 
 
+from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
 from reflex.utils import serializers
 from reflex.utils import serializers
 
 
@@ -332,6 +333,18 @@ def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
     return type_ in allowed_types
     return type_ in allowed_types
 
 
 
 
+def is_encoded_fstring(value) -> bool:
+    """Check if a value is an encoded Var f-string.
+
+    Args:
+        value: The value string to check.
+
+    Returns:
+        Whether the value is an f-string
+    """
+    return isinstance(value, str) and constants.REFLEX_VAR_OPENING_TAG in value
+
+
 def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str):
 def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str):
     """Check that a value is a valid literal.
     """Check that a value is a valid literal.
 
 
@@ -349,6 +362,7 @@ def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str):
     if (
     if (
         is_literal(expected_type)
         is_literal(expected_type)
         and not isinstance(value, Var)  # validating vars is not supported yet.
         and not isinstance(value, Var)  # validating vars is not supported yet.
+        and not is_encoded_fstring(value)  # f-strings are not supported.
         and value not in expected_type.__args__
         and value not in expected_type.__args__
     ):
     ):
         allowed_values = expected_type.__args__
         allowed_values = expected_type.__args__

+ 5 - 2
reflex/vars.py

@@ -202,7 +202,10 @@ def _encode_var(value: Var) -> str:
         The encoded var.
         The encoded var.
     """
     """
     if value._var_data:
     if value._var_data:
-        return f"<reflex.Var>{value._var_data.json()}</reflex.Var>" + str(value)
+        return (
+            f"{constants.REFLEX_VAR_OPENING_TAG}{value._var_data.json()}{constants.REFLEX_VAR_CLOSING_TAG}"
+            + str(value)
+        )
     return str(value)
     return str(value)
 
 
 
 
@@ -219,7 +222,7 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
     if isinstance(value, str):
     if isinstance(value, str):
         # Extract the state name from a formatted var
         # Extract the state name from a formatted var
         while m := re.match(
         while m := re.match(
-            pattern=r"(.*)<reflex.Var>(.*)</reflex.Var>(.*)",
+            pattern=rf"(.*){constants.REFLEX_VAR_OPENING_TAG}(.*){constants.REFLEX_VAR_CLOSING_TAG}(.*)",
             string=value,
             string=value,
             flags=re.DOTALL,  # Ensure . matches newline characters.
             flags=re.DOTALL,  # Ensure . matches newline characters.
         ):
         ):

+ 27 - 5
tests/components/core/test_colors.py

@@ -1,24 +1,46 @@
 import pytest
 import pytest
 
 
 import reflex as rx
 import reflex as rx
+from reflex.vars import Var
 
 
 
 
 class ColorState(rx.State):
 class ColorState(rx.State):
     """Test color state."""
     """Test color state."""
 
 
     color: str = "mint"
     color: str = "mint"
+    color_part: str = "tom"
     shade: int = 4
     shade: int = 4
 
 
 
 
+def create_color_var(color):
+    return Var.create(color)
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "color, expected",
     "color, expected",
     [
     [
-        (rx.color("mint"), "{`var(--mint-7)`}"),
-        (rx.color("mint", 3), "{`var(--mint-3)`}"),
-        (rx.color("mint", 3, True), "{`var(--mint-a3)`}"),
+        (create_color_var(rx.color("mint")), "var(--mint-7)"),
+        (create_color_var(rx.color("mint", 3)), "var(--mint-3)"),
+        (create_color_var(rx.color("mint", 3, True)), "var(--mint-a3)"),
+        (
+            create_color_var(rx.color(ColorState.color, ColorState.shade)),  # type: ignore
+            "var(--${state__color_state.color}-${state__color_state.shade})",
+        ),
+        (
+            create_color_var(rx.color(f"{ColorState.color}", f"{ColorState.shade}")),  # type: ignore
+            "var(--${state__color_state.color}-${state__color_state.shade})",
+        ),
+        (
+            create_color_var(rx.color(f"{ColorState.color_part}ato", f"{ColorState.shade}")),  # type: ignore
+            "var(--${state__color_state.color_part}ato-${state__color_state.shade})",
+        ),
+        (
+            create_color_var(f'{rx.color(ColorState.color, f"{ColorState.shade}")}'),  # type: ignore
+            "var(--${state__color_state.color}-${state__color_state.shade})",
+        ),
         (
         (
-            rx.color(ColorState.color, ColorState.shade),  # type: ignore
-            "{`var(--${state__color_state.color}-${state__color_state.shade})`}",
+            create_color_var(f'{rx.color(f"{ColorState.color}", f"{ColorState.shade}")}'),  # type: ignore
+            "var(--${state__color_state.color}-${state__color_state.shade})",
         ),
         ),
     ],
     ],
 )
 )