Quellcode durchsuchen

Get `rx.color` working with fstrings (#2562)

* fix for rx.color working with fstrings

* Fix fstrings issues
Elijah Ahianyo vor 1 Jahr
Ursprung
Commit
ccc9c32c95

+ 2 - 3
reflex/components/core/colors.py

@@ -2,11 +2,10 @@
 
 from reflex.constants.colors import Color, ColorType, ShadeType
 from reflex.utils.types import validate_parameter_literals
-from reflex.vars import Var
 
 
 @validate_parameter_literals
-def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Var:
+def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Color:
     """Create a color object.
 
     Args:
@@ -17,4 +16,4 @@ def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Var:
     Returns:
         The color object.
     """
-    return Var.create(Color(color, shade, alpha))._replace(_var_is_string=True)  # type: ignore
+    return Color(color, shade, alpha)

+ 12 - 0
reflex/components/core/cond.py

@@ -7,6 +7,7 @@ from reflex.components.base.fragment import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.components.tags import CondTag, Tag
 from reflex.constants import Dirs
+from reflex.constants.colors import Color
 from reflex.utils import format, imports
 from reflex.vars import BaseVar, Var, VarData
 
@@ -167,6 +168,17 @@ def cond(condition: Any, c1: Any, c2: Any = None):
     if isinstance(c2, Var):
         var_datas.append(c2._var_data)
 
+    def create_var(cond_part):
+        return Var.create_safe(
+            cond_part,
+            _var_is_string=type(cond_part) is str or isinstance(cond_part, Color),
+        )
+
+    # convert the truth and false cond parts into vars so the _var_data can be obtained.
+    c1 = create_var(c1)
+    c2 = create_var(c2)
+    var_datas.extend([c1._var_data, c2._var_data])
+
     # Create the conditional var.
     return cond_var._replace(
         _var_name=format.format_cond(

+ 3 - 1
reflex/components/core/match.py

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 from reflex.components.base import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
+from reflex.components.core.colors import Color
 from reflex.components.tags import MatchTag, Tag
 from reflex.style import Style
 from reflex.utils import format, imports, types
@@ -116,7 +117,8 @@ class Match(MemoizationLeaf):
         """
         _var_data = case_element._var_data if isinstance(case_element, Style) else None  # type: ignore
         case_element = Var.create(
-            case_element, _var_is_string=type(case_element) is str
+            case_element,
+            _var_is_string=type(case_element) is str or isinstance(case_element, Color),
         )
         if _var_data is not None:
             case_element._var_data = VarData.merge(case_element._var_data, _var_data)  # type: ignore

+ 4 - 0
reflex/constants/__init__.py

@@ -6,6 +6,8 @@ from .base import (
     LOCAL_STORAGE,
     POLLING_MAX_HTTP_BUFFER_SIZE,
     PYTEST_CURRENT_TEST,
+    REFLEX_VAR_CLOSING_TAG,
+    REFLEX_VAR_OPENING_TAG,
     RELOAD_CONFIG,
     SKIP_COMPILE_ENV_VAR,
     ColorMode,
@@ -73,6 +75,8 @@ __ALL__ = [
     Expiration,
     Ext,
     Fnm,
+    REFLEX_VAR_CLOSING_TAG,
+    REFLEX_VAR_OPENING_TAG,
     GitIgnore,
     Hooks,
     Imports,

+ 3 - 0
reflex/constants/base.py

@@ -187,3 +187,6 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
 # Testing os env set by pytest when running a test case.
 PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
 RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"
+
+REFLEX_VAR_OPENING_TAG = "<reflex.Var>"
+REFLEX_VAR_CLOSING_TAG = "</reflex.Var>"

+ 14 - 0
reflex/utils/types.py

@@ -25,6 +25,7 @@ from pydantic.fields import ModelField
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
 
+from reflex import constants
 from reflex.base import Base
 from reflex.utils import serializers
 
@@ -332,6 +333,18 @@ def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
     return type_ in allowed_types
 
 
+def is_encoded_fstring(value) -> bool:
+    """Check if a value is an encoded Var f-string.
+
+    Args:
+        value: The value string to check.
+
+    Returns:
+        Whether the value is an f-string
+    """
+    return isinstance(value, str) and constants.REFLEX_VAR_OPENING_TAG in value
+
+
 def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str):
     """Check that a value is a valid literal.
 
@@ -349,6 +362,7 @@ def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str):
     if (
         is_literal(expected_type)
         and not isinstance(value, Var)  # validating vars is not supported yet.
+        and not is_encoded_fstring(value)  # f-strings are not supported.
         and value not in expected_type.__args__
     ):
         allowed_values = expected_type.__args__

+ 5 - 2
reflex/vars.py

@@ -202,7 +202,10 @@ def _encode_var(value: Var) -> str:
         The encoded var.
     """
     if value._var_data:
-        return f"<reflex.Var>{value._var_data.json()}</reflex.Var>" + str(value)
+        return (
+            f"{constants.REFLEX_VAR_OPENING_TAG}{value._var_data.json()}{constants.REFLEX_VAR_CLOSING_TAG}"
+            + str(value)
+        )
     return str(value)
 
 
@@ -219,7 +222,7 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
     if isinstance(value, str):
         # Extract the state name from a formatted var
         while m := re.match(
-            pattern=r"(.*)<reflex.Var>(.*)</reflex.Var>(.*)",
+            pattern=rf"(.*){constants.REFLEX_VAR_OPENING_TAG}(.*){constants.REFLEX_VAR_CLOSING_TAG}(.*)",
             string=value,
             flags=re.DOTALL,  # Ensure . matches newline characters.
         ):

+ 27 - 5
tests/components/core/test_colors.py

@@ -1,24 +1,46 @@
 import pytest
 
 import reflex as rx
+from reflex.vars import Var
 
 
 class ColorState(rx.State):
     """Test color state."""
 
     color: str = "mint"
+    color_part: str = "tom"
     shade: int = 4
 
 
+def create_color_var(color):
+    return Var.create(color)
+
+
 @pytest.mark.parametrize(
     "color, expected",
     [
-        (rx.color("mint"), "{`var(--mint-7)`}"),
-        (rx.color("mint", 3), "{`var(--mint-3)`}"),
-        (rx.color("mint", 3, True), "{`var(--mint-a3)`}"),
+        (create_color_var(rx.color("mint")), "var(--mint-7)"),
+        (create_color_var(rx.color("mint", 3)), "var(--mint-3)"),
+        (create_color_var(rx.color("mint", 3, True)), "var(--mint-a3)"),
+        (
+            create_color_var(rx.color(ColorState.color, ColorState.shade)),  # type: ignore
+            "var(--${state__color_state.color}-${state__color_state.shade})",
+        ),
+        (
+            create_color_var(rx.color(f"{ColorState.color}", f"{ColorState.shade}")),  # type: ignore
+            "var(--${state__color_state.color}-${state__color_state.shade})",
+        ),
+        (
+            create_color_var(rx.color(f"{ColorState.color_part}ato", f"{ColorState.shade}")),  # type: ignore
+            "var(--${state__color_state.color_part}ato-${state__color_state.shade})",
+        ),
+        (
+            create_color_var(f'{rx.color(ColorState.color, f"{ColorState.shade}")}'),  # type: ignore
+            "var(--${state__color_state.color}-${state__color_state.shade})",
+        ),
         (
-            rx.color(ColorState.color, ColorState.shade),  # type: ignore
-            "{`var(--${state__color_state.color}-${state__color_state.shade})`}",
+            create_color_var(f'{rx.color(f"{ColorState.color}", f"{ColorState.shade}")}'),  # type: ignore
+            "var(--${state__color_state.color}-${state__color_state.shade})",
         ),
     ],
 )