瀏覽代碼

allow classname to be state vars (#3991)

* allow classname to be state vars

* simplify join with all literal string vars

* add test case and avoid concat var operation if it's not necessary

* remove silly print statement

* simplify case where there's no var

* don't automatically do class name str to literal var
Khaleel Al-Adhami 8 月之前
父節點
當前提交
74d1c47ce2

+ 7 - 1
reflex/components/component.py

@@ -55,6 +55,7 @@ from reflex.utils.imports import (
 )
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
+from reflex.vars.sequence import LiteralArrayVar
 
 
 class BaseComponent(Base, ABC):
@@ -496,7 +497,12 @@ class Component(BaseComponent, ABC):
         # Convert class_name to str if it's list
         class_name = kwargs.get("class_name", "")
         if isinstance(class_name, (List, tuple)):
-            kwargs["class_name"] = " ".join(class_name)
+            if any(isinstance(c, Var) for c in class_name):
+                kwargs["class_name"] = LiteralArrayVar.create(
+                    class_name, _var_type=List[str]
+                ).join(" ")
+            else:
+                kwargs["class_name"] = " ".join(class_name)
 
         # Construct the component.
         super().__init__(*args, **kwargs)

+ 1 - 1
reflex/components/radix/themes/layout/stack.py

@@ -33,7 +33,7 @@ class Stack(Flex):
         """
         # Apply the default classname
         given_class_name = props.pop("class_name", [])
-        if isinstance(given_class_name, str):
+        if not isinstance(given_class_name, list):
             given_class_name = [given_class_name]
         props["class_name"] = ["rx-Stack", *given_class_name]
 

+ 43 - 0
reflex/vars/sequence.py

@@ -592,6 +592,29 @@ class LiteralStringVar(LiteralVar, StringVar):
                 else:
                     return only_string.to(StringVar, only_string._var_type)
 
+            if len(
+                literal_strings := [
+                    s
+                    for s in filtered_strings_and_vals
+                    if isinstance(s, (str, LiteralStringVar))
+                ]
+            ) == len(filtered_strings_and_vals):
+                return LiteralStringVar.create(
+                    "".join(
+                        s._var_value if isinstance(s, LiteralStringVar) else s
+                        for s in literal_strings
+                    ),
+                    _var_type=_var_type,
+                    _var_data=VarData.merge(
+                        _var_data,
+                        *(
+                            s._get_all_var_data()
+                            for s in filtered_strings_and_vals
+                            if isinstance(s, Var)
+                        ),
+                    ),
+                )
+
             concat_result = ConcatVarOperation.create(
                 *filtered_strings_and_vals,
                 _var_data=_var_data,
@@ -736,6 +759,26 @@ class ArrayVar(Var[ARRAY_VAR_TYPE]):
         """
         if not isinstance(sep, (StringVar, str)):
             raise_unsupported_operand_types("join", (type(self), type(sep)))
+        if (
+            isinstance(self, LiteralArrayVar)
+            and (
+                len(
+                    args := [
+                        x
+                        for x in self._var_value
+                        if isinstance(x, (LiteralStringVar, str))
+                    ]
+                )
+                == len(self._var_value)
+            )
+            and isinstance(sep, (LiteralStringVar, str))
+        ):
+            sep_str = sep._var_value if isinstance(sep, LiteralStringVar) else sep
+            return LiteralStringVar.create(
+                sep_str.join(
+                    i._var_value if isinstance(i, LiteralStringVar) else i for i in args
+                )
+            )
         return array_join_operation(self, sep)
 
     def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]:

+ 10 - 0
tests/components/test_component.py

@@ -1288,6 +1288,16 @@ class EventState(rx.State):
             [FORMATTED_TEST_VAR],
             id="fstring-class_name",
         ),
+        pytest.param(
+            rx.fragment(class_name=f"foo{TEST_VAR}bar other-class"),
+            [LiteralVar.create(f"{FORMATTED_TEST_VAR} other-class")],
+            id="fstring-dual-class_name",
+        ),
+        pytest.param(
+            rx.fragment(class_name=[TEST_VAR, "other-class"]),
+            [LiteralVar.create([TEST_VAR, "other-class"]).join(" ")],
+            id="fstring-dual-class_name",
+        ),
         pytest.param(
             rx.fragment(special_props=[TEST_VAR]),
             [TEST_VAR],