Browse Source

[REF-2802] Foreach should respect modifications to children (#3263)

* Unit tests for add_style and component styles with foreach

The styles should be correctly applied for components that are rendered as part
of a foreach.

* [REF-2802] Foreach should respect modifications to children

Components are mutable, and there is logic that depends on walking through the
component tree and making modifications to components along the way. These
modifications _must_ be respected by foreach for consistency.

Modifications necessary to fix the bug:

* Change the hash function in `_render` to get a hash over the render_fn's
  `__code__` object. This way we get a stable hash without having to call the
  render function with bogus values.
* Call the render function once during `create` and save the result as a child
  of the Foreach component (tree walks will modify this instance).
* Directly render the original (and possibly modified) child component instead
  of calling the render_fn again and creating a new component instance at
  render time.

Additional changes because they're nice:

* Deprecate passing `**props` to `rx.foreach`. No one should have been
  doing this anyway, because it just does not work in any reasonable way.
* Raise `ForeachVarError` when the iterable type is Any
* Raise `ForeachRenderError` when the render function does not take 1 or 2 args.
* Link to the foreach component docs when either of those errors are hit.
* Change the `iterable` arg in `create` to accept `Var[Iterable] | Iterable`
  for better typing support (and remove some type: ignore comments)
* Simplify `_render` and `render` methods -- remove unused and potentially
  confusing code.

* Fixup: `to_bytes` requires `byteorder` arg before py3.11
Masen Furer 1 year ago
parent
commit
d767dc5fc7

+ 59 - 33
reflex/components/core/foreach.py

@@ -2,16 +2,24 @@
 from __future__ import annotations
 
 import inspect
-from hashlib import md5
 from typing import Any, Callable, Iterable
 
 from reflex.components.base.fragment import Fragment
 from reflex.components.component import Component
 from reflex.components.tags import IterTag
 from reflex.constants import MemoizationMode
+from reflex.utils import console
 from reflex.vars import Var
 
 
+class ForeachVarError(TypeError):
+    """Raised when the iterable type is Any."""
+
+
+class ForeachRenderError(TypeError):
+    """Raised when there is an error with the foreach render function."""
+
+
 class Foreach(Component):
     """A component that takes in an iterable and a render function and renders a list of components."""
 
@@ -24,56 +32,84 @@ class Foreach(Component):
     render_fn: Callable = Fragment.create
 
     @classmethod
-    def create(cls, iterable: Var[Iterable], render_fn: Callable, **props) -> Foreach:
+    def create(
+        cls,
+        iterable: Var[Iterable] | Iterable,
+        render_fn: Callable,
+        **props,
+    ) -> Foreach:
         """Create a foreach component.
 
         Args:
             iterable: The iterable to create components from.
             render_fn: A function from the render args to the component.
-            **props: The attributes to pass to each child component.
+            **props: The attributes to pass to each child component (deprecated).
 
         Returns:
             The foreach component.
 
         Raises:
-            TypeError: If the iterable is of type Any.
+            ForeachVarError: If the iterable is of type Any.
         """
-        iterable = Var.create(iterable)  # type: ignore
+        if props:
+            console.deprecate(
+                feature_name="Passing props to rx.foreach",
+                reason="it does not have the intended effect and may be confusing",
+                deprecation_version="0.5.0",
+                removal_version="0.6.0",
+            )
+        iterable = Var.create_safe(iterable)
         if iterable._var_type == Any:
-            raise TypeError(
-                f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)"
+            raise ForeachVarError(
+                f"Could not foreach over var `{iterable._var_full_name}` of type Any. "
+                "(If you are trying to foreach over a state var, add a type annotation to the var). "
+                "See https://reflex.dev/docs/library/layout/foreach/"
             )
         component = cls(
             iterable=iterable,
             render_fn=render_fn,
-            **props,
         )
-        # Keep a ref to a rendered component to determine correct imports.
-        component.children = [
-            component._render(props=dict(index_var_name="i")).render_component()
-        ]
+        # Keep a ref to a rendered component to determine correct imports/hooks/styles.
+        component.children = [component._render().render_component()]
         return component
 
-    def _render(self, props: dict[str, Any] | None = None) -> IterTag:
-        props = {} if props is None else props.copy()
+    def _render(self) -> IterTag:
+        props = {}
 
-        # Determine the arg var name based on the params accepted by render_fn.
         render_sig = inspect.signature(self.render_fn)
         params = list(render_sig.parameters.values())
+
+        # Validate the render function signature.
+        if len(params) == 0 or len(params) > 2:
+            raise ForeachRenderError(
+                "Expected 1 or 2 parameters in foreach render function, got "
+                f"{[p.name for p in params]}. See https://reflex.dev/docs/library/layout/foreach/"
+            )
+
         if len(params) >= 1:
-            props.setdefault("arg_var_name", params[0].name)
+            # Determine the arg var name based on the params accepted by render_fn.
+            props["arg_var_name"] = params[0].name
 
-        if len(params) >= 2:
+        if len(params) == 2:
             # Determine the index var name based on the params accepted by render_fn.
-            props.setdefault("index_var_name", params[1].name)
-        elif "index_var_name" not in props:
-            # Otherwise, use a deterministic index, based on the rendered code.
-            code_hash = md5(str(self.children[0].render()).encode("utf-8")).hexdigest()
-            props.setdefault("index_var_name", f"index_{code_hash}")
+            props["index_var_name"] = params[1].name
+        else:
+            # Otherwise, use a deterministic index, based on the render function bytecode.
+            code_hash = (
+                hash(self.render_fn.__code__)
+                .to_bytes(
+                    length=8,
+                    byteorder="big",
+                    signed=True,
+                )
+                .hex()
+            )
+            props["index_var_name"] = f"index_{code_hash}"
 
         return IterTag(
             iterable=self.iterable,
             render_fn=self.render_fn,
+            children=self.children,
             **props,
         )
 
@@ -84,19 +120,9 @@ class Foreach(Component):
             The dictionary for template of component.
         """
         tag = self._render()
-        component = tag.render_component()
 
         return dict(
-            tag.add_props(
-                **self.event_triggers,
-                key=self.key,
-                sx=self.style,
-                id=self.id,
-                class_name=self.class_name,
-            ).set(
-                children=[component.render()],
-                props=tag.format_props(),
-            ),
+            tag,
             iterable_state=tag.iterable._var_full_name,
             arg_name=tag.arg_var_name,
             arg_index=tag.get_index_var_arg(),

+ 33 - 17
tests/components/core/test_foreach.py

@@ -2,17 +2,11 @@ from typing import Dict, List, Set, Tuple, Union
 
 import pytest
 
-from reflex.components import box, foreach, text
-from reflex.components.core import Foreach
+from reflex.components import box, el, foreach, text
+from reflex.components.core.foreach import Foreach, ForeachRenderError, ForeachVarError
 from reflex.state import BaseState
 from reflex.vars import Var
 
-try:
-    # When pydantic v2 is installed
-    from pydantic.v1 import ValidationError  # type: ignore
-except ImportError:
-    from pydantic import ValidationError
-
 
 class ForEachState(BaseState):
     """A state for testing the ForEach component."""
@@ -84,12 +78,12 @@ def display_nested_color_with_shades_v2(color):
 
 def display_color_tuple(color):
     assert color._var_type == str
-    return box(text(color, "tuple"))
+    return box(text(color))
 
 
 def display_colors_set(color):
     assert color._var_type == str
-    return box(text(color, "set"))
+    return box(text(color))
 
 
 def display_nested_list_element(element: Var[str], index: Var[int]):
@@ -100,7 +94,7 @@ def display_nested_list_element(element: Var[str], index: Var[int]):
 
 def display_color_index_tuple(color):
     assert color._var_type == Union[int, str]
-    return box(text(color, "index_tuple"))
+    return box(text(color))
 
 
 seen_index_vars = set()
@@ -215,24 +209,46 @@ def test_foreach_render(state_var, render_fn, render_dict):
 
     # Make sure the index vars are unique.
     arg_index = rend["arg_index"]
+    assert isinstance(arg_index, Var)
     assert arg_index._var_name not in seen_index_vars
     assert arg_index._var_type == int
     seen_index_vars.add(arg_index._var_name)
 
 
 def test_foreach_bad_annotations():
-    """Test that the foreach component raises a TypeError if the iterable is of type Any."""
-    with pytest.raises(TypeError):
+    """Test that the foreach component raises a ForeachVarError if the iterable is of type Any."""
+    with pytest.raises(ForeachVarError):
         Foreach.create(
-            ForEachState.bad_annotation_list,  # type: ignore
+            ForEachState.bad_annotation_list,
             lambda sublist: Foreach.create(sublist, lambda color: text(color)),
         )
 
 
 def test_foreach_no_param_in_signature():
-    """Test that the foreach component raises a TypeError if no parameters are passed."""
-    with pytest.raises(ValidationError):
+    """Test that the foreach component raises a ForeachRenderError if no parameters are passed."""
+    with pytest.raises(ForeachRenderError):
         Foreach.create(
-            ForEachState.colors_list,  # type: ignore
+            ForEachState.colors_list,
             lambda: text("color"),
         )
+
+
+def test_foreach_too_many_params_in_signature():
+    """Test that the foreach component raises a ForeachRenderError if too many parameters are passed."""
+    with pytest.raises(ForeachRenderError):
+        Foreach.create(
+            ForEachState.colors_list,
+            lambda color, index, extra: text(color),
+        )
+
+
+def test_foreach_component_styles():
+    """Test that the foreach component works with global component styles."""
+    component = el.div(
+        foreach(
+            ForEachState.colors_list,
+            display_color,
+        )
+    )
+    component._add_style_recursive({box: {"color": "red"}})
+    assert 'css={{"color": "red"}}' in str(component)

+ 21 - 0
tests/components/test_component.py

@@ -1951,3 +1951,24 @@ def test_component_add_custom_code():
         "const custom_code5 = 46",
         "const custom_code6 = 47",
     }
+
+
+def test_add_style_foreach():
+    class StyledComponent(Component):
+        tag = "StyledComponent"
+        ix: Var[int]
+
+        def add_style(self):
+            return Style({"color": "red"})
+
+    page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i)))
+    page._add_style_recursive(Style())
+
+    # Expect only a single child of the foreach on the python side
+    assert len(page.children[0].children) == 1
+
+    # Expect the style to be added to the child of the foreach
+    assert 'css={{"color": "red"}}' in str(page.children[0].children[0])
+
+    # Expect only one instance of this CSS dict in the rendered page
+    assert str(page).count('css={{"color": "red"}}') == 1