Ver código fonte

what if i deleted rx.foreach

Khaleel Al-Adhami 4 meses atrás
pai
commit
5d6b51c561

+ 27 - 2
reflex/components/base/bare.py

@@ -4,12 +4,12 @@ from __future__ import annotations
 
 
 from typing import Any, Iterator
 from typing import Any, Iterator
 
 
-from reflex.components.component import Component
+from reflex.components.component import Component, ComponentStyle
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
 from reflex.components.tags.tagless import Tagless
 from reflex.components.tags.tagless import Tagless
 from reflex.utils.imports import ParsedImportDict
 from reflex.utils.imports import ParsedImportDict
 from reflex.vars import BooleanVar, ObjectVar, Var
 from reflex.vars import BooleanVar, ObjectVar, Var
-from reflex.vars.base import VarData
+from reflex.vars.base import VarData, get_var_caching, set_var_caching
 
 
 
 
 class Bare(Component):
 class Bare(Component):
@@ -141,6 +141,31 @@ class Bare(Component):
             return Tagless(contents=f"{{{self.contents!s}}}")
             return Tagless(contents=f"{{{self.contents!s}}}")
         return Tagless(contents=str(self.contents))
         return Tagless(contents=str(self.contents))
 
 
+    def _add_style_recursive(
+        self, style: ComponentStyle, theme: Component | None = None
+    ) -> Component:
+        """Add style to the component and its children.
+
+        Args:
+            style: The style to add.
+            theme: The theme to add.
+
+        Returns:
+            The component with the style added.
+        """
+        new_self = super()._add_style_recursive(style, theme)
+        if isinstance(self.contents, Var):
+            var_data = self.contents._get_all_var_data()
+            if var_data:
+                for component in var_data.components:
+                    if isinstance(component, Component):
+                        component._add_style_recursive(style, theme)
+        if get_var_caching():
+            set_var_caching(False)
+            str(new_self)
+            set_var_caching(True)
+        return new_self
+
     def _get_vars(
     def _get_vars(
         self, include_children: bool = False, ignore_ids: set[int] | None = None
         self, include_children: bool = False, ignore_ids: set[int] | None = None
     ) -> Iterator[Var]:
     ) -> Iterator[Var]:

+ 1 - 11
reflex/components/component.py

@@ -931,7 +931,6 @@ class Component(BaseComponent, ABC):
         """
         """
         from reflex.components.base.bare import Bare
         from reflex.components.base.bare import Bare
         from reflex.components.base.fragment import Fragment
         from reflex.components.base.fragment import Fragment
-        from reflex.components.core.foreach import Foreach
 
 
         no_valid_parents_defined = all(child._valid_parents == [] for child in children)
         no_valid_parents_defined = all(child._valid_parents == [] for child in children)
         if (
         if (
@@ -942,7 +941,7 @@ class Component(BaseComponent, ABC):
             return
             return
 
 
         comp_name = type(self).__name__
         comp_name = type(self).__name__
-        allowed_components = [comp.__name__ for comp in (Fragment, Foreach)]
+        allowed_components = [comp.__name__ for comp in (Fragment,)]
 
 
         def validate_child(child):
         def validate_child(child):
             child_name = type(child).__name__
             child_name = type(child).__name__
@@ -1974,8 +1973,6 @@ class StatefulComponent(BaseComponent):
         Returns:
         Returns:
             The stateful component or None if the component should not be memoized.
             The stateful component or None if the component should not be memoized.
         """
         """
-        from reflex.components.core.foreach import Foreach
-
         if component._memoization_mode.disposition == MemoizationDisposition.NEVER:
         if component._memoization_mode.disposition == MemoizationDisposition.NEVER:
             # Never memoize this component.
             # Never memoize this component.
             return None
             return None
@@ -2004,10 +2001,6 @@ class StatefulComponent(BaseComponent):
                 # Skip BaseComponent and StatefulComponent children.
                 # Skip BaseComponent and StatefulComponent children.
                 if not isinstance(child, Component):
                 if not isinstance(child, Component):
                     continue
                     continue
-                # Always consider Foreach something that must be memoized by the parent.
-                if isinstance(child, Foreach):
-                    should_memoize = True
-                    break
                 child = cls._child_var(child)
                 child = cls._child_var(child)
                 if isinstance(child, Var) and child._get_all_var_data():
                 if isinstance(child, Var) and child._get_all_var_data():
                     should_memoize = True
                     should_memoize = True
@@ -2057,12 +2050,9 @@ class StatefulComponent(BaseComponent):
             The Var from the child component or the child itself (for regular cases).
             The Var from the child component or the child itself (for regular cases).
         """
         """
         from reflex.components.base.bare import Bare
         from reflex.components.base.bare import Bare
-        from reflex.components.core.foreach import Foreach
 
 
         if isinstance(child, Bare):
         if isinstance(child, Bare):
             return child.contents
             return child.contents
-        if isinstance(child, Foreach):
-            return child.iterable
         return child
         return child
 
 
     @classmethod
     @classmethod

+ 0 - 1
reflex/components/core/__init__.py

@@ -25,7 +25,6 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
     "debounce": ["DebounceInput", "debounce_input"],
     "debounce": ["DebounceInput", "debounce_input"],
     "foreach": [
     "foreach": [
         "foreach",
         "foreach",
-        "Foreach",
     ],
     ],
     "html": ["html", "Html"],
     "html": ["html", "Html"],
     "match": [
     "match": [

+ 0 - 1
reflex/components/core/__init__.pyi

@@ -21,7 +21,6 @@ from .cond import color_mode_cond as color_mode_cond
 from .cond import cond as cond
 from .cond import cond as cond
 from .debounce import DebounceInput as DebounceInput
 from .debounce import DebounceInput as DebounceInput
 from .debounce import debounce_input as debounce_input
 from .debounce import debounce_input as debounce_input
-from .foreach import Foreach as Foreach
 from .foreach import foreach as foreach
 from .foreach import foreach as foreach
 from .html import Html as Html
 from .html import Html as Html
 from .html import html as html
 from .html import html as html

+ 30 - 118
reflex/components/core/foreach.py

@@ -2,15 +2,11 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
-import inspect
-from typing import Any, Callable, Iterable
+from typing import Callable, Iterable
 
 
-from reflex.components.base.fragment import Fragment
-from reflex.components.component import Component
-from reflex.components.tags import IterTag
-from reflex.constants import MemoizationMode
-from reflex.state import ComponentState
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.base import LiteralVar, Var
+from reflex.vars.object import ObjectVar
+from reflex.vars.sequence import ArrayVar
 
 
 
 
 class ForeachVarError(TypeError):
 class ForeachVarError(TypeError):
@@ -21,116 +17,32 @@ class ForeachRenderError(TypeError):
     """Raised when there is an error with the foreach render function."""
     """Raised when there is an error with the foreach render function."""
 
 
 
 
-class Foreach(Component):
-    """A component that takes in an iterable and a render function and renders a list of components."""
-
-    _memoization_mode = MemoizationMode(recursive=False)
-
-    # The iterable to create components from.
-    iterable: Var[Iterable]
-
-    # A function from the render args to the component.
-    render_fn: Callable = Fragment.create
-
-    @classmethod
-    def create(
-        cls,
-        iterable: Var[Iterable] | Iterable,
-        render_fn: Callable,
-    ) -> Foreach:
-        """Create a foreach component.
-
-        Args:
-            iterable: The iterable to create components from.
-            render_fn: A function from the render args to the component.
-
-        Returns:
-            The foreach component.
-
-        Raises:
-            ForeachVarError: If the iterable is of type Any.
-            TypeError: If the render function is a ComponentState.
-        """
-        iterable = LiteralVar.create(iterable)
-        if iterable._var_type == Any:
-            raise ForeachVarError(
-                f"Could not foreach over var `{iterable!s}` of type Any. "
-                "(If you are trying to foreach over a state var, add a type annotation to the var). "
-                "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
-            )
-
-        if (
-            hasattr(render_fn, "__qualname__")
-            and render_fn.__qualname__ == ComponentState.create.__qualname__
-        ):
-            raise TypeError(
-                "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet."
-            )
-
-        component = cls(
-            iterable=iterable,
-            render_fn=render_fn,
+def foreach(
+    iterable: Var[Iterable] | Iterable,
+    render_fn: Callable,
+) -> Var:
+    """Create a foreach component.
+
+    Args:
+        iterable: The iterable to create components from.
+        render_fn: A function from the render args to the component.
+
+    Returns:
+        The foreach component.
+
+    Raises:
+        ForeachVarError: If the iterable is of type Any.
+        TypeError: If the render function is a ComponentState.
+    """
+    iterable = LiteralVar.create(iterable)
+    if isinstance(iterable, ObjectVar):
+        iterable = iterable.items()
+
+    if not isinstance(iterable, ArrayVar):
+        raise ForeachVarError(
+            f"Could not foreach over var `{iterable!s}` of type {iterable._var_type!s}. "
+            "(If you are trying to foreach over a state var, add a type annotation to the var). "
+            "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
         )
         )
-        # Keep a ref to a rendered component to determine correct imports/hooks/styles.
-        component.children = [component._render().render_component()]
-        return component
-
-    def _render(self) -> IterTag:
-        props = {}
-
-        render_sig = inspect.signature(self.render_fn)
-        params = list(render_sig.parameters.values())
-
-        # Validate the render function signature.
-        if len(params) == 0 or len(params) > 2:
-            raise ForeachRenderError(
-                "Expected 1 or 2 parameters in foreach render function, got "
-                f"{[p.name for p in params]}. See "
-                "https://reflex.dev/docs/library/dynamic-rendering/foreach/"
-            )
-
-        if len(params) >= 1:
-            # Determine the arg var name based on the params accepted by render_fn.
-            props["arg_var_name"] = params[0].name
-
-        if len(params) == 2:
-            # Determine the index var name based on the params accepted by render_fn.
-            props["index_var_name"] = params[1].name
-        else:
-            # Otherwise, use a deterministic index, based on the render function bytecode.
-            code_hash = (
-                hash(self.render_fn.__code__)
-                .to_bytes(
-                    length=8,
-                    byteorder="big",
-                    signed=True,
-                )
-                .hex()
-            )
-            props["index_var_name"] = f"index_{code_hash}"
-
-        return IterTag(
-            iterable=self.iterable,
-            render_fn=self.render_fn,
-            children=self.children,
-            **props,
-        )
-
-    def render(self):
-        """Render the component.
-
-        Returns:
-            The dictionary for template of component.
-        """
-        tag = self._render()
-
-        return dict(
-            tag,
-            iterable_state=str(tag.iterable),
-            arg_name=tag.arg_var_name,
-            arg_index=tag.get_index_var_arg(),
-            iterable_type=tag.iterable._var_type.mro()[0].__name__,
-        )
-
 
 
-foreach = Foreach.create
+    return iterable.foreach(render_fn)

+ 1 - 1
reflex/components/radix/primitives/slider.py

@@ -188,7 +188,7 @@ class Slider(ComponentNamespace):
         else:
         else:
             children = [
             children = [
                 track,
                 track,
-                #     Foreach.create(props.get("value"), lambda e: SliderThumb.create()),  # foreach doesn't render Thumbs properly # noqa: ERA001
+                #     foreach(props.get("value"), lambda e: SliderThumb.create()),  # foreach doesn't render Thumbs properly # noqa: ERA001
             ]
             ]
 
 
         return SliderRoot.create(*children, **props)
         return SliderRoot.create(*children, **props)

+ 2 - 2
reflex/components/radix/themes/layout/list.py

@@ -5,7 +5,7 @@ from __future__ import annotations
 from typing import Any, Iterable, Literal, Union
 from typing import Any, Iterable, Literal, Union
 
 
 from reflex.components.component import Component, ComponentNamespace
 from reflex.components.component import Component, ComponentNamespace
-from reflex.components.core.foreach import Foreach
+from reflex.components.core.foreach import foreach
 from reflex.components.el.elements.typography import Li, Ol, Ul
 from reflex.components.el.elements.typography import Li, Ol, Ul
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
 from reflex.components.markdown.markdown import MarkdownComponentMap
 from reflex.components.markdown.markdown import MarkdownComponentMap
@@ -70,7 +70,7 @@ class BaseList(Component, MarkdownComponentMap):
 
 
         if not children and items is not None:
         if not children and items is not None:
             if isinstance(items, Var):
             if isinstance(items, Var):
-                children = [Foreach.create(items, ListItem.create)]
+                children = [foreach(items, ListItem.create)]
             else:
             else:
                 children = [ListItem.create(item) for item in items]  # type: ignore
                 children = [ListItem.create(item) for item in items]  # type: ignore
         props["direction"] = "column"
         props["direction"] = "column"

+ 0 - 1
reflex/components/tags/__init__.py

@@ -1,4 +1,3 @@
 """Representations for React tags."""
 """Representations for React tags."""
 
 
-from .iter_tag import IterTag
 from .tag import Tag
 from .tag import Tag

+ 0 - 141
reflex/components/tags/iter_tag.py

@@ -1,141 +0,0 @@
-"""Tag to loop through a list of components."""
-
-from __future__ import annotations
-
-import dataclasses
-import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args
-
-from reflex.components.tags.tag import Tag
-from reflex.utils import types
-from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
-
-if TYPE_CHECKING:
-    from reflex.components.component import Component
-
-
-@dataclasses.dataclass()
-class IterTag(Tag):
-    """An iterator tag."""
-
-    # The var to iterate over.
-    iterable: Var[Iterable] = dataclasses.field(
-        default_factory=lambda: LiteralArrayVar.create([])
-    )
-
-    # The component render function for each item in the iterable.
-    render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x)
-
-    # The name of the arg var.
-    arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
-
-    # The name of the index var.
-    index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
-
-    def get_iterable_var_type(self) -> types.GenericType:
-        """Get the type of the iterable var.
-
-        Returns:
-            The type of the iterable var.
-        """
-        iterable = self.iterable
-        try:
-            if iterable._var_type.mro()[0] is dict:
-                # Arg is a tuple of (key, value).
-                return Tuple[get_args(iterable._var_type)]
-            elif iterable._var_type.mro()[0] is tuple:
-                # Arg is a union of any possible values in the tuple.
-                return Union[get_args(iterable._var_type)]
-            else:
-                return get_args(iterable._var_type)[0]
-        except Exception:
-            return Any
-
-    def get_index_var(self) -> Var:
-        """Get the index var for the tag (with curly braces).
-
-        This is used to reference the index var within the tag.
-
-        Returns:
-            The index var.
-        """
-        return Var(
-            _js_expr=self.index_var_name,
-            _var_type=int,
-        ).guess_type()
-
-    def get_arg_var(self) -> Var:
-        """Get the arg var for the tag (with curly braces).
-
-        This is used to reference the arg var within the tag.
-
-        Returns:
-            The arg var.
-        """
-        return Var(
-            _js_expr=self.arg_var_name,
-            _var_type=self.get_iterable_var_type(),
-        ).guess_type()
-
-    def get_index_var_arg(self) -> Var:
-        """Get the index var for the tag (without curly braces).
-
-        This is used to render the index var in the .map() function.
-
-        Returns:
-            The index var.
-        """
-        return Var(
-            _js_expr=self.index_var_name,
-            _var_type=int,
-        ).guess_type()
-
-    def get_arg_var_arg(self) -> Var:
-        """Get the arg var for the tag (without curly braces).
-
-        This is used to render the arg var in the .map() function.
-
-        Returns:
-            The arg var.
-        """
-        return Var(
-            _js_expr=self.arg_var_name,
-            _var_type=self.get_iterable_var_type(),
-        ).guess_type()
-
-    def render_component(self) -> Component:
-        """Render the component.
-
-        Raises:
-            ValueError: If the render function takes more than 2 arguments.
-
-        Returns:
-            The rendered component.
-        """
-        # Import here to avoid circular imports.
-        from reflex.components.base.fragment import Fragment
-        from reflex.components.core.foreach import Foreach
-
-        # Get the render function arguments.
-        args = inspect.getfullargspec(self.render_fn).args
-        arg = self.get_arg_var()
-        index = self.get_index_var()
-
-        if len(args) == 1:
-            # If the render function doesn't take the index as an argument.
-            component = self.render_fn(arg)
-        else:
-            # If the render function takes the index as an argument.
-            if len(args) != 2:
-                raise ValueError("The render function must take 2 arguments.")
-            component = self.render_fn(arg, index)
-
-        # Nested foreach components or cond must be wrapped in fragments.
-        if isinstance(component, (Foreach, Var)):
-            component = Fragment.create(component)
-
-        # Set the component key.
-        if component.key is None:
-            component.key = index
-
-        return component

+ 11 - 0
reflex/utils/types.py

@@ -890,12 +890,23 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
     Returns:
     Returns:
         Whether the type hint is a subclass of the other type hint.
         Whether the type hint is a subclass of the other type hint.
     """
     """
+    if isinstance(possible_subclass, Sequence) and isinstance(
+        possible_superclass, Sequence
+    ):
+        return all(
+            typehint_issubclass(subclass, superclass)
+            for subclass, superclass in zip(possible_subclass, possible_superclass)
+        )
     if possible_subclass is possible_superclass:
     if possible_subclass is possible_superclass:
         return True
         return True
     if possible_superclass is Any:
     if possible_superclass is Any:
         return True
         return True
     if possible_subclass is Any:
     if possible_subclass is Any:
         return False
         return False
+    if isinstance(
+        possible_subclass, (TypeVar, typing_extensions.TypeVar)
+    ) or isinstance(possible_superclass, (TypeVar, typing_extensions.TypeVar)):
+        return True
 
 
     provided_type_origin = get_origin(possible_subclass)
     provided_type_origin = get_origin(possible_subclass)
     accepted_type_origin = get_origin(possible_superclass)
     accepted_type_origin = get_origin(possible_superclass)

+ 41 - 0
reflex/vars/base.py

@@ -151,6 +151,28 @@ def unwrap_reflex_callalbe(
     return args
     return args
 
 
 
 
+_VAR_CACHING = True
+
+
+def get_var_caching() -> bool:
+    """Get the var caching status.
+
+    Returns:
+        The var caching status.
+    """
+    return _VAR_CACHING
+
+
+def set_var_caching(value: bool):
+    """Set the var caching status.
+
+    Args:
+        value: The value to set the var caching status to.
+    """
+    global _VAR_CACHING
+    _VAR_CACHING = value
+
+
 @dataclasses.dataclass(
 @dataclasses.dataclass(
     eq=False,
     eq=False,
     frozen=True,
     frozen=True,
@@ -1186,6 +1208,25 @@ class Var(Generic[VAR_TYPE]):
         """
         """
         return self
         return self
 
 
+    def __getattribute__(self, name: str) -> Any:
+        """Get an attribute of the var.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The attribute.
+        """
+        if not _VAR_CACHING:
+            try:
+                self_dict = object.__getattribute__(self, "__dict__")
+                for key in self_dict:
+                    if key.startswith("_cached_"):
+                        del self_dict[key]
+            except Exception:
+                pass
+        return super().__getattribute__(name)
+
     def __getattr__(self, name: str):
     def __getattr__(self, name: str):
         """Get an attribute of the var.
         """Get an attribute of the var.
 
 

+ 55 - 9
reflex/vars/sequence.py

@@ -741,7 +741,8 @@ if TYPE_CHECKING:
 def map_array_operation(
 def map_array_operation(
     array: Var[Sequence[INNER_ARRAY_VAR]],
     array: Var[Sequence[INNER_ARRAY_VAR]],
     function: Var[
     function: Var[
-        ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]
+        ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR]
+        | ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]
         | ReflexCallable[[], ANOTHER_ARRAY_VAR]
         | ReflexCallable[[], ANOTHER_ARRAY_VAR]
     ],
     ],
 ) -> CustomVarOperationReturn[Sequence[ANOTHER_ARRAY_VAR]]:
 ) -> CustomVarOperationReturn[Sequence[ANOTHER_ARRAY_VAR]]:
@@ -973,7 +974,8 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
 
 
     def foreach(
     def foreach(
         self: ArrayVar[Sequence[INNER_ARRAY_VAR]],
         self: ArrayVar[Sequence[INNER_ARRAY_VAR]],
-        fn: Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR]
+        fn: Callable[[Var[INNER_ARRAY_VAR], NumberVar[int]], ANOTHER_ARRAY_VAR]
+        | Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR]
         | Callable[[], ANOTHER_ARRAY_VAR],
         | Callable[[], ANOTHER_ARRAY_VAR],
     ) -> ArrayVar[Sequence[ANOTHER_ARRAY_VAR]]:
     ) -> ArrayVar[Sequence[ANOTHER_ARRAY_VAR]]:
         """Apply a function to each element of the array.
         """Apply a function to each element of the array.
@@ -987,21 +989,36 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
         Raises:
         Raises:
             VarTypeError: If the function takes more than one argument.
             VarTypeError: If the function takes more than one argument.
         """
         """
+        from reflex.state import ComponentState
+
         from .function import ArgsFunctionOperation
         from .function import ArgsFunctionOperation
 
 
         if not callable(fn):
         if not callable(fn):
             raise_unsupported_operand_types("foreach", (type(self), type(fn)))
             raise_unsupported_operand_types("foreach", (type(self), type(fn)))
         # get the number of arguments of the function
         # get the number of arguments of the function
         num_args = len(inspect.signature(fn).parameters)
         num_args = len(inspect.signature(fn).parameters)
-        if num_args > 1:
+        if num_args > 2:
             raise VarTypeError(
             raise VarTypeError(
-                "The function passed to foreach should take at most one argument."
+                "The function passed to foreach should take at most two arguments."
+            )
+
+        if (
+            hasattr(fn, "__qualname__")
+            and fn.__qualname__ == ComponentState.create.__qualname__
+        ):
+            raise TypeError(
+                "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet."
             )
             )
 
 
         if num_args == 0:
         if num_args == 0:
-            return_value = fn()  # type: ignore
+            fn_result = fn()  # pyright: ignore [reportCallIssue]
+            return_value = Var.create(fn_result)
             simple_function_var: FunctionVar[ReflexCallable[[], ANOTHER_ARRAY_VAR]] = (
             simple_function_var: FunctionVar[ReflexCallable[[], ANOTHER_ARRAY_VAR]] = (
-                ArgsFunctionOperation.create((), return_value)
+                ArgsFunctionOperation.create(
+                    (),
+                    return_value,
+                    _var_type=ReflexCallable[[], return_value._var_type],
+                )
             )
             )
             return map_array_operation(self, simple_function_var).guess_type()
             return map_array_operation(self, simple_function_var).guess_type()
 
 
@@ -1021,11 +1038,40 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
             ).guess_type(),
             ).guess_type(),
         )
         )
 
 
+        if num_args == 1:
+            fn_result = fn(first_arg)  # pyright: ignore [reportCallIssue]
+
+            return_value = Var.create(fn_result)
+
+            function_var = cast(
+                Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]],
+                ArgsFunctionOperation.create(
+                    (arg_name,),
+                    return_value,
+                    _var_type=ReflexCallable[[first_arg_type], return_value._var_type],
+                ),
+            )
+
+            return map_array_operation.call(self, function_var).guess_type()
+
+        second_arg = cast(
+            NumberVar[int],
+            Var(
+                _js_expr=get_unique_variable_name(),
+                _var_type=int,
+            ).guess_type(),
+        )
+
+        fn_result = fn(first_arg, second_arg)  # pyright: ignore [reportCallIssue]
+
+        return_value = Var.create(fn_result)
+
         function_var = cast(
         function_var = cast(
-            Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]],
+            Var[ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR]],
             ArgsFunctionOperation.create(
             ArgsFunctionOperation.create(
-                (arg_name,),
-                Var.create(fn(first_arg)),  # type: ignore
+                (arg_name, second_arg._js_expr),
+                return_value,
+                _var_type=ReflexCallable[[first_arg_type, int], return_value._var_type],
             ),
             ),
         )
         )
 
 

+ 22 - 135
tests/units/components/core/test_foreach.py

@@ -6,16 +6,11 @@ import pytest
 from reflex import el
 from reflex import el
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.component import Component
 from reflex.components.component import Component
-from reflex.components.core.foreach import (
-    Foreach,
-    ForeachRenderError,
-    ForeachVarError,
-    foreach,
-)
+from reflex.components.core.foreach import ForeachVarError, foreach
 from reflex.components.radix.themes.layout.box import box
 from reflex.components.radix.themes.layout.box import box
 from reflex.components.radix.themes.typography.text import text
 from reflex.components.radix.themes.typography.text import text
 from reflex.state import BaseState, ComponentState
 from reflex.state import BaseState, ComponentState
-from reflex.vars.base import Var
+from reflex.utils.exceptions import VarTypeError
 from reflex.vars.number import NumberVar
 from reflex.vars.number import NumberVar
 from reflex.vars.sequence import ArrayVar
 from reflex.vars.sequence import ArrayVar
 
 
@@ -141,143 +136,35 @@ def display_color_index_tuple(color):
 seen_index_vars = set()
 seen_index_vars = set()
 
 
 
 
-@pytest.mark.parametrize(
-    "state_var, render_fn, render_dict",
-    [
-        (
-            ForEachState.colors_list,
-            display_color,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.colors_list",
-                "iterable_type": "list",
-            },
-        ),
-        (
-            ForEachState.colors_dict_list,
-            display_color_name,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list",
-                "iterable_type": "list",
-            },
-        ),
-        (
-            ForEachState.colors_nested_dict_list,
-            display_shade,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list",
-                "iterable_type": "list",
-            },
-        ),
-        (
-            ForEachState.primary_color,
-            display_primary_colors,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.primary_color",
-                "iterable_type": "dict",
-            },
-        ),
-        (
-            ForEachState.color_with_shades,
-            display_color_with_shades,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.color_with_shades",
-                "iterable_type": "dict",
-            },
-        ),
-        (
-            ForEachState.nested_colors_with_shades,
-            display_nested_color_with_shades,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades",
-                "iterable_type": "dict",
-            },
-        ),
-        (
-            ForEachState.nested_colors_with_shades,
-            display_nested_color_with_shades_v2,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades",
-                "iterable_type": "dict",
-            },
-        ),
-        (
-            ForEachState.color_tuple,
-            display_color_tuple,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.color_tuple",
-                "iterable_type": "tuple",
-            },
-        ),
-        (
-            ForEachState.colors_set,
-            display_colors_set,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.colors_set",
-                "iterable_type": "set",
-            },
-        ),
-        (
-            ForEachState.nested_colors_list,
-            lambda el, i: display_nested_list_element(el, i),
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list",
-                "iterable_type": "list",
-            },
-        ),
-        (
-            ForEachState.color_index_tuple,
-            display_color_index_tuple,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple",
-                "iterable_type": "tuple",
-            },
-        ),
-    ],
-)
-def test_foreach_render(state_var, render_fn, render_dict):
-    """Test that the foreach component renders without error.
-
-    Args:
-        state_var: the state var.
-        render_fn: The render callable
-        render_dict: return dict on calling `component.render`
-    """
-    component = Foreach.create(state_var, render_fn)
-
-    rend = component.render()
-    assert rend["iterable_state"] == render_dict["iterable_state"]
-    assert rend["iterable_type"] == render_dict["iterable_type"]
-
-    # Make sure the index vars are unique.
-    arg_index = rend["arg_index"]
-    assert isinstance(arg_index, Var)
-    assert arg_index._js_expr not in seen_index_vars
-    assert arg_index._var_type is int
-    seen_index_vars.add(arg_index._js_expr)
-
-
 def test_foreach_bad_annotations():
 def test_foreach_bad_annotations():
     """Test that the foreach component raises a ForeachVarError if the iterable is of type Any."""
     """Test that the foreach component raises a ForeachVarError if the iterable is of type Any."""
     with pytest.raises(ForeachVarError):
     with pytest.raises(ForeachVarError):
-        Foreach.create(
+        foreach(
             ForEachState.bad_annotation_list,
             ForEachState.bad_annotation_list,
-            lambda sublist: Foreach.create(sublist, lambda color: text(color)),
+            lambda sublist: foreach(sublist, lambda color: text(color)),
         )
         )
 
 
 
 
 def test_foreach_no_param_in_signature():
 def test_foreach_no_param_in_signature():
-    """Test that the foreach component raises a ForeachRenderError if no parameters are passed."""
-    with pytest.raises(ForeachRenderError):
-        Foreach.create(
-            ForEachState.colors_list,
-            lambda: text("color"),
-        )
+    """Test that the foreach component DOES NOT raise an error if no parameters are passed."""
+    foreach(
+        ForEachState.colors_list,
+        lambda: text("color"),
+    )
+
+
+def test_foreach_with_index():
+    """Test that the foreach component works with an index."""
+    foreach(
+        ForEachState.colors_list,
+        lambda color, index: text(color, index),
+    )
 
 
 
 
 def test_foreach_too_many_params_in_signature():
 def test_foreach_too_many_params_in_signature():
     """Test that the foreach component raises a ForeachRenderError if too many parameters are passed."""
     """Test that the foreach component raises a ForeachRenderError if too many parameters are passed."""
-    with pytest.raises(ForeachRenderError):
-        Foreach.create(
+    with pytest.raises(VarTypeError):
+        foreach(
             ForEachState.colors_list,
             ForEachState.colors_list,
             lambda color, index, extra: text(color),
             lambda color, index, extra: text(color),
         )
         )
@@ -292,13 +179,13 @@ def test_foreach_component_styles():
         )
         )
     )
     )
     component._add_style_recursive({box: {"color": "red"}})
     component._add_style_recursive({box: {"color": "red"}})
-    assert 'css={({ ["color"] : "red" })}' in str(component)
+    assert '{ ["css"] : ({ ["color"] : "red" }) }' in str(component)
 
 
 
 
 def test_foreach_component_state():
 def test_foreach_component_state():
     """Test that using a component state to render in the foreach raises an error."""
     """Test that using a component state to render in the foreach raises an error."""
     with pytest.raises(TypeError):
     with pytest.raises(TypeError):
-        Foreach.create(
+        foreach(
             ForEachState.colors_list,
             ForEachState.colors_list,
             ComponentStateTest.create,
             ComponentStateTest.create,
         )
         )
@@ -306,7 +193,7 @@ def test_foreach_component_state():
 
 
 def test_foreach_default_factory():
 def test_foreach_default_factory():
     """Test that the default factory is called."""
     """Test that the default factory is called."""
-    _ = Foreach.create(
+    _ = foreach(
         ForEachState.default_factory_list,
         ForEachState.default_factory_list,
         lambda tag: text(tag.name),
         lambda tag: text(tag.name),
     )
     )

+ 2 - 6
tests/units/components/test_component.py

@@ -1446,7 +1446,6 @@ def test_instantiate_all_components():
     untested_components = {
     untested_components = {
         "Card",
         "Card",
         "DebounceInput",
         "DebounceInput",
-        "Foreach",
         "FormControl",
         "FormControl",
         "Html",
         "Html",
         "Icon",
         "Icon",
@@ -2147,14 +2146,11 @@ def test_add_style_foreach():
     page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i)))
     page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i)))
     page._add_style_recursive(Style())
     page._add_style_recursive(Style())
 
 
-    # Expect only a single child of the foreach on the python side
-    assert len(page.children[0].children) == 1
-
     # Expect the style to be added to the child of the foreach
     # Expect the style to be added to the child of the foreach
-    assert 'css={({ ["color"] : "red" })}' in str(page.children[0].children[0])
+    assert '({ ["css"] : ({ ["color"] : "red" }) }),' in str(page.children[0])
 
 
     # Expect only one instance of this CSS dict in the rendered page
     # Expect only one instance of this CSS dict in the rendered page
-    assert str(page).count('css={({ ["color"] : "red" })}') == 1
+    assert str(page).count('({ ["css"] : ({ ["color"] : "red" }) }),') == 1
 
 
 
 
 class TriggerState(rx.State):
 class TriggerState(rx.State):