Browse Source

Allow `Component.add_style` to return a regular dict (#3264)

* Allow `Component.add_style` to return a regular dict

It's more convenient to allow returning a regular dict without having to import
and wrap the value in `rx.style.Style`.

If the dict contains any Var or encoded VarData f-strings, these will be picked
up when the plain dicts are passed to Style.update().

Because Style.update already merges VarData, there is no reason to explicitly
merge it again in this function; this change keeps the merging logic inside the
Style class.

* Test for Style.update with existing Style with _var_data and kwargs

Should retain the _var_data from the original Style instance

* style: Avoid losing VarData in Style.update

If a Style class with _var_data is passed to `Style.update` along with kwargs,
then the _var_data was lost in the double-splat dictionary expansion.

Instead, only apply the kwargs to an existing or new Style instance to retain
_var_data and properly convert values.

* add_style return annotation is Dict[str, Any]

* nit: use lowercase dict in annotation
Masen Furer 1 year ago
parent
commit
3715462eb4
4 changed files with 78 additions and 8 deletions
  1. 1 5
      reflex/components/component.py
  2. 5 2
      reflex/style.py
  3. 49 0
      tests/components/test_component.py
  4. 23 1
      tests/test_style.py

+ 1 - 5
reflex/components/component.py

@@ -782,7 +782,7 @@ class Component(BaseComponent, ABC):
 
         return cls(children=children, **props)
 
-    def add_style(self) -> Style | None:
+    def add_style(self) -> dict[str, Any] | None:
         """Add style to the component.
 
         Downstream components can override this method to return a style dict
@@ -802,20 +802,16 @@ class Component(BaseComponent, ABC):
             The style to add.
         """
         styles = []
-        vars = []
 
         # Walk the MRO to call all `add_style` methods.
         for base in self._iter_parent_classes_with_method("add_style"):
             s = base.add_style(self)  # type: ignore
             if s is not None:
                 styles.append(s)
-                vars.append(s._var_data)
 
         _style = Style()
         for s in reversed(styles):
             _style.update(s)
-
-        _style._var_data = VarData.merge(*vars)
         return _style
 
     def _get_component_style(self, styles: ComponentStyle) -> Style | None:

+ 5 - 2
reflex/style.py

@@ -180,12 +180,15 @@ class Style(dict):
             style_dict: The style dictionary.
             kwargs: Other key value pairs to apply to the dict update.
         """
-        if kwargs:
-            style_dict = {**(style_dict or {}), **kwargs}
         if not isinstance(style_dict, Style):
             converted_dict = type(self)(style_dict)
         else:
             converted_dict = style_dict
+        if kwargs:
+            if converted_dict is None:
+                converted_dict = type(self)(kwargs)
+            else:
+                converted_dict.update(kwargs)
         # Combine our VarData with that of any Vars in the style_dict that was passed.
         self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
         super().update(converted_dict)

+ 49 - 0
tests/components/test_component.py

@@ -1953,6 +1953,55 @@ def test_component_add_custom_code():
     }
 
 
+def test_add_style_embedded_vars(test_state: BaseState):
+    """Test that add_style works with embedded vars when returning a plain dict.
+
+    Args:
+        test_state: A test state.
+    """
+    v0 = Var.create_safe("parent")._replace(
+        merge_var_data=VarData(hooks={"useParent": None}),  # type: ignore
+    )
+    v1 = rx.color("plum", 10)
+    v2 = Var.create_safe("text")._replace(
+        merge_var_data=VarData(hooks={"useText": None}),  # type: ignore
+    )
+
+    class ParentComponent(Component):
+        def add_style(self):
+            return Style(
+                {
+                    "fake_parent": v0,
+                }
+            )
+
+    class StyledComponent(ParentComponent):
+        tag = "StyledComponent"
+
+        def add_style(self):
+            return {
+                "color": v1,
+                "fake": v2,
+                "margin": f"{test_state.num}%",
+            }
+
+    page = rx.vstack(StyledComponent.create())
+    page._add_style_recursive(Style())
+
+    assert (
+        "const test_state = useContext(StateContexts.test_state)"
+        in page._get_all_hooks_internal()
+    )
+    assert "useText" in page._get_all_hooks_internal()
+    assert "useParent" in page._get_all_hooks_internal()
+    assert (
+        str(page).count(
+            'css={{"fakeParent": "parent", "color": "var(--plum-10)", "fake": "text", "margin": `${test_state.num}%`}}'
+        )
+        == 1
+    )
+
+
 def test_add_style_foreach():
     class StyledComponent(Component):
         tag = "StyledComponent"

+ 23 - 1
tests/test_style.py

@@ -8,7 +8,7 @@ import reflex as rx
 from reflex import style
 from reflex.components.component import evaluate_style_namespaces
 from reflex.style import Style
-from reflex.vars import Var
+from reflex.vars import Var, VarData
 
 test_style = [
     ({"a": 1}, {"a": 1}),
@@ -503,3 +503,25 @@ def test_evaluate_style_namespaces():
     assert rx.text.__call__ not in style_dict
     style_dict = evaluate_style_namespaces(style_dict)  # type: ignore
     assert rx.text.__call__ in style_dict
+
+
+def test_style_update_with_var_data():
+    """Test that .update with a Style containing VarData works."""
+    red_var = Var.create_safe("red")._replace(
+        merge_var_data=VarData(hooks={"const red = true": None}),  # type: ignore
+    )
+    blue_var = Var.create_safe("blue", _var_is_local=False)._replace(
+        merge_var_data=VarData(hooks={"const blue = true": None}),  # type: ignore
+    )
+
+    s1 = Style(
+        {
+            "color": red_var,
+        }
+    )
+    s2 = Style()
+    s2.update(s1, background_color=f"{blue_var}ish")
+    assert s2 == {"color": "red", "backgroundColor": "`${blue}ish`"}
+    assert s2._var_data is not None
+    assert "const red = true" in s2._var_data.hooks
+    assert "const blue = true" in s2._var_data.hooks