Forráskód Böngészése

Set unique index vars in rx.foreach (#2126)

Nikhil Rao 1 éve
szülő
commit
e703d87450

+ 12 - 6
reflex/components/layout/foreach.py

@@ -1,6 +1,7 @@
 """Create a list of components from an iterable."""
 from __future__ import annotations
 
+import typing
 from typing import Any, Callable, Iterable
 
 from reflex.components.component import Component
@@ -47,15 +48,20 @@ class Foreach(Component):
                 f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)"
             )
         arg = BaseVar(_var_name="_", _var_type=type_, _var_is_local=True)
+        comp = IterTag(iterable=iterable, render_fn=render_fn).render_component(arg)
         return cls(
             iterable=iterable,
             render_fn=render_fn,
-            children=[IterTag.render_component(render_fn, arg=arg)],
+            children=[comp],
             **props,
         )
 
     def _render(self) -> IterTag:
-        return IterTag(iterable=self.iterable, render_fn=self.render_fn)
+        return IterTag(
+            iterable=self.iterable,
+            render_fn=self.render_fn,
+            index_var_name=get_unique_variable_name(),
+        )
 
     def render(self):
         """Render the component.
@@ -66,9 +72,9 @@ class Foreach(Component):
         tag = self._render()
         try:
             type_ = (
-                self.iterable._var_type
-                if self.iterable._var_type.mro()[0] == dict
-                else self.iterable._var_type.__args__[0]
+                tag.iterable._var_type
+                if tag.iterable._var_type.mro()[0] == dict
+                else typing.get_args(tag.iterable._var_type)[0]
             )
         except Exception:
             type_ = Any
@@ -77,7 +83,7 @@ class Foreach(Component):
             _var_type=type_,
         )
         index_arg = tag.get_index_var_arg()
-        component = tag.render_component(self.render_fn, arg)
+        component = tag.render_component(arg)
         return dict(
             tag.add_props(
                 **self.event_triggers,

+ 18 - 18
reflex/components/tags/iter_tag.py

@@ -11,9 +11,6 @@ if TYPE_CHECKING:
     from reflex.components.component import Component
 
 
-INDEX_VAR = "i"
-
-
 class IterTag(Tag):
     """An iterator tag."""
 
@@ -23,37 +20,40 @@ class IterTag(Tag):
     # The component render function for each item in the iterable.
     render_fn: Callable
 
-    @staticmethod
-    def get_index_var() -> Var:
-        """Get the index var for the tag.
+    # The name of the index var.
+    index_var_name: str = "i"
+
+    def get_index_var(self) -> Var:
+        """Get the index var for the tag (with curly braces).
+
+        This is used to reference the index var within the tag.
 
         Returns:
             The index var.
         """
         return BaseVar(
-            _var_name=INDEX_VAR,
+            _var_name=self.index_var_name,
             _var_type=int,
         )
 
-    @staticmethod
-    def get_index_var_arg() -> Var:
-        """Get the index var for the tag.
+    def get_index_var_arg(self) -> Var:
+        """Get the index var for the tag (without curly braces).
+
+        This is used to render the index var in the .map() function.
 
         Returns:
             The index var.
         """
         return BaseVar(
-            _var_name=INDEX_VAR,
+            _var_name=self.index_var_name,
             _var_type=int,
             _var_is_local=True,
         )
 
-    @staticmethod
-    def render_component(render_fn: Callable, arg: Var) -> Component:
+    def render_component(self, arg: Var) -> Component:
         """Render the component.
 
         Args:
-            render_fn: The render function.
             arg: The argument to pass to the render function.
 
         Returns:
@@ -65,16 +65,16 @@ class IterTag(Tag):
         from reflex.components.layout.fragment import Fragment
 
         # Get the render function arguments.
-        args = inspect.getfullargspec(render_fn).args
-        index = IterTag.get_index_var()
+        args = inspect.getfullargspec(self.render_fn).args
+        index = self.get_index_var()
 
         if len(args) == 1:
             # If the render function doesn't take the index as an argument.
-            component = render_fn(arg)
+            component = self.render_fn(arg)
         else:
             # If the render function takes the index as an argument.
             assert len(args) == 2
-            component = render_fn(arg, index)
+            component = self.render_fn(arg, index)
 
         # Nested foreach components or cond must be wrapped in fragments.
         if isinstance(component, (Foreach, Cond)):

+ 9 - 13
tests/components/layout/test_foreach.py

@@ -78,6 +78,9 @@ def display_nested_list_element(element: str, index: int):
     return box(text(element[index]))
 
 
+seen_index_vars = set()
+
+
 @pytest.mark.parametrize(
     "state_var, render_fn, render_dict",
     [
@@ -86,7 +89,6 @@ def display_nested_list_element(element: str, index: int):
             display_color,
             {
                 "iterable_state": "for_each_state.colors_list",
-                "arg_index": "i",
                 "iterable_type": "list",
             },
         ),
@@ -95,7 +97,6 @@ def display_nested_list_element(element: str, index: int):
             display_color_name,
             {
                 "iterable_state": "for_each_state.colors_dict_list",
-                "arg_index": "i",
                 "iterable_type": "list",
             },
         ),
@@ -104,7 +105,6 @@ def display_nested_list_element(element: str, index: int):
             display_shade,
             {
                 "iterable_state": "for_each_state.colors_nested_dict_list",
-                "arg_index": "i",
                 "iterable_type": "list",
             },
         ),
@@ -113,7 +113,6 @@ def display_nested_list_element(element: str, index: int):
             display_primary_colors,
             {
                 "iterable_state": "for_each_state.primary_color",
-                "arg_index": "i",
                 "iterable_type": "dict",
             },
         ),
@@ -122,7 +121,6 @@ def display_nested_list_element(element: str, index: int):
             display_color_with_shades,
             {
                 "iterable_state": "for_each_state.color_with_shades",
-                "arg_index": "i",
                 "iterable_type": "dict",
             },
         ),
@@ -131,7 +129,6 @@ def display_nested_list_element(element: str, index: int):
             display_nested_color_with_shades,
             {
                 "iterable_state": "for_each_state.nested_colors_with_shades",
-                "arg_index": "i",
                 "iterable_type": "dict",
             },
         ),
@@ -140,7 +137,6 @@ def display_nested_list_element(element: str, index: int):
             display_nested_color_with_shades_v2,
             {
                 "iterable_state": "for_each_state.nested_colors_with_shades",
-                "arg_index": "i",
                 "iterable_type": "dict",
             },
         ),
@@ -149,7 +145,6 @@ def display_nested_list_element(element: str, index: int):
             display_color_tuple,
             {
                 "iterable_state": "for_each_state.color_tuple",
-                "arg_index": "i",
                 "iterable_type": "tuple",
             },
         ),
@@ -158,7 +153,6 @@ def display_nested_list_element(element: str, index: int):
             display_colors_set,
             {
                 "iterable_state": "for_each_state.colors_set",
-                "arg_index": "i",
                 "iterable_type": "set",
             },
         ),
@@ -167,7 +161,6 @@ def display_nested_list_element(element: str, index: int):
             lambda el, i: display_nested_list_element(el, i),
             {
                 "iterable_state": "for_each_state.nested_colors_list",
-                "arg_index": "i",
                 "iterable_type": "list",
             },
         ),
@@ -184,8 +177,11 @@ def test_foreach_render(state_var, render_fn, render_dict):
     component = Foreach.create(state_var, render_fn)
 
     rend = component.render()
-    arg_index = rend["arg_index"]
     assert rend["iterable_state"] == render_dict["iterable_state"]
-    assert arg_index._var_name == render_dict["arg_index"]
-    assert arg_index._var_type == int
     assert rend["iterable_type"] == render_dict["iterable_type"]
+
+    # Make sure the index vars are unique.
+    arg_index = rend["arg_index"]
+    assert arg_index._var_name not in seen_index_vars
+    assert arg_index._var_type == int
+    seen_index_vars.add(arg_index._var_name)