소스 검색

what if i deleted rx.foreach

Khaleel Al-Adhami 4 달 전
부모
커밋
5d6b51c561

+ 27 - 2
reflex/components/base/bare.py

@@ -4,12 +4,12 @@ from __future__ import annotations
 
 from typing import Any, Iterator
 
-from reflex.components.component import Component
+from reflex.components.component import Component, ComponentStyle
 from reflex.components.tags import Tag
 from reflex.components.tags.tagless import Tagless
 from reflex.utils.imports import ParsedImportDict
 from reflex.vars import BooleanVar, ObjectVar, Var
-from reflex.vars.base import VarData
+from reflex.vars.base import VarData, get_var_caching, set_var_caching
 
 
 class Bare(Component):
@@ -141,6 +141,31 @@ class Bare(Component):
             return Tagless(contents=f"{{{self.contents!s}}}")
         return Tagless(contents=str(self.contents))
 
+    def _add_style_recursive(
+        self, style: ComponentStyle, theme: Component | None = None
+    ) -> Component:
+        """Add style to the component and its children.
+
+        Args:
+            style: The style to add.
+            theme: The theme to add.
+
+        Returns:
+            The component with the style added.
+        """
+        new_self = super()._add_style_recursive(style, theme)
+        if isinstance(self.contents, Var):
+            var_data = self.contents._get_all_var_data()
+            if var_data:
+                for component in var_data.components:
+                    if isinstance(component, Component):
+                        component._add_style_recursive(style, theme)
+        if get_var_caching():
+            set_var_caching(False)
+            str(new_self)
+            set_var_caching(True)
+        return new_self
+
     def _get_vars(
         self, include_children: bool = False, ignore_ids: set[int] | None = None
     ) -> Iterator[Var]:

+ 1 - 11
reflex/components/component.py

@@ -931,7 +931,6 @@ class Component(BaseComponent, ABC):
         """
         from reflex.components.base.bare import Bare
         from reflex.components.base.fragment import Fragment
-        from reflex.components.core.foreach import Foreach
 
         no_valid_parents_defined = all(child._valid_parents == [] for child in children)
         if (
@@ -942,7 +941,7 @@ class Component(BaseComponent, ABC):
             return
 
         comp_name = type(self).__name__
-        allowed_components = [comp.__name__ for comp in (Fragment, Foreach)]
+        allowed_components = [comp.__name__ for comp in (Fragment,)]
 
         def validate_child(child):
             child_name = type(child).__name__
@@ -1974,8 +1973,6 @@ class StatefulComponent(BaseComponent):
         Returns:
             The stateful component or None if the component should not be memoized.
         """
-        from reflex.components.core.foreach import Foreach
-
         if component._memoization_mode.disposition == MemoizationDisposition.NEVER:
             # Never memoize this component.
             return None
@@ -2004,10 +2001,6 @@ class StatefulComponent(BaseComponent):
                 # Skip BaseComponent and StatefulComponent children.
                 if not isinstance(child, Component):
                     continue
-                # Always consider Foreach something that must be memoized by the parent.
-                if isinstance(child, Foreach):
-                    should_memoize = True
-                    break
                 child = cls._child_var(child)
                 if isinstance(child, Var) and child._get_all_var_data():
                     should_memoize = True
@@ -2057,12 +2050,9 @@ class StatefulComponent(BaseComponent):
             The Var from the child component or the child itself (for regular cases).
         """
         from reflex.components.base.bare import Bare
-        from reflex.components.core.foreach import Foreach
 
         if isinstance(child, Bare):
             return child.contents
-        if isinstance(child, Foreach):
-            return child.iterable
         return child
 
     @classmethod

+ 0 - 1
reflex/components/core/__init__.py

@@ -25,7 +25,6 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
     "debounce": ["DebounceInput", "debounce_input"],
     "foreach": [
         "foreach",
-        "Foreach",
     ],
     "html": ["html", "Html"],
     "match": [

+ 0 - 1
reflex/components/core/__init__.pyi

@@ -21,7 +21,6 @@ from .cond import color_mode_cond as color_mode_cond
 from .cond import cond as cond
 from .debounce import DebounceInput as DebounceInput
 from .debounce import debounce_input as debounce_input
-from .foreach import Foreach as Foreach
 from .foreach import foreach as foreach
 from .html import Html as Html
 from .html import html as html

+ 30 - 118
reflex/components/core/foreach.py

@@ -2,15 +2,11 @@
 
 from __future__ import annotations
 
-import inspect
-from typing import Any, Callable, Iterable
+from typing import Callable, Iterable
 
-from reflex.components.base.fragment import Fragment
-from reflex.components.component import Component
-from reflex.components.tags import IterTag
-from reflex.constants import MemoizationMode
-from reflex.state import ComponentState
 from reflex.vars.base import LiteralVar, Var
+from reflex.vars.object import ObjectVar
+from reflex.vars.sequence import ArrayVar
 
 
 class ForeachVarError(TypeError):
@@ -21,116 +17,32 @@ class ForeachRenderError(TypeError):
     """Raised when there is an error with the foreach render function."""
 
 
-class Foreach(Component):
-    """A component that takes in an iterable and a render function and renders a list of components."""
-
-    _memoization_mode = MemoizationMode(recursive=False)
-
-    # The iterable to create components from.
-    iterable: Var[Iterable]
-
-    # A function from the render args to the component.
-    render_fn: Callable = Fragment.create
-
-    @classmethod
-    def create(
-        cls,
-        iterable: Var[Iterable] | Iterable,
-        render_fn: Callable,
-    ) -> Foreach:
-        """Create a foreach component.
-
-        Args:
-            iterable: The iterable to create components from.
-            render_fn: A function from the render args to the component.
-
-        Returns:
-            The foreach component.
-
-        Raises:
-            ForeachVarError: If the iterable is of type Any.
-            TypeError: If the render function is a ComponentState.
-        """
-        iterable = LiteralVar.create(iterable)
-        if iterable._var_type == Any:
-            raise ForeachVarError(
-                f"Could not foreach over var `{iterable!s}` of type Any. "
-                "(If you are trying to foreach over a state var, add a type annotation to the var). "
-                "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
-            )
-
-        if (
-            hasattr(render_fn, "__qualname__")
-            and render_fn.__qualname__ == ComponentState.create.__qualname__
-        ):
-            raise TypeError(
-                "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet."
-            )
-
-        component = cls(
-            iterable=iterable,
-            render_fn=render_fn,
+def foreach(
+    iterable: Var[Iterable] | Iterable,
+    render_fn: Callable,
+) -> Var:
+    """Create a foreach component.
+
+    Args:
+        iterable: The iterable to create components from.
+        render_fn: A function from the render args to the component.
+
+    Returns:
+        The foreach component.
+
+    Raises:
+        ForeachVarError: If the iterable is of type Any.
+        TypeError: If the render function is a ComponentState.
+    """
+    iterable = LiteralVar.create(iterable)
+    if isinstance(iterable, ObjectVar):
+        iterable = iterable.items()
+
+    if not isinstance(iterable, ArrayVar):
+        raise ForeachVarError(
+            f"Could not foreach over var `{iterable!s}` of type {iterable._var_type!s}. "
+            "(If you are trying to foreach over a state var, add a type annotation to the var). "
+            "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
         )
-        # Keep a ref to a rendered component to determine correct imports/hooks/styles.
-        component.children = [component._render().render_component()]
-        return component
-
-    def _render(self) -> IterTag:
-        props = {}
-
-        render_sig = inspect.signature(self.render_fn)
-        params = list(render_sig.parameters.values())
-
-        # Validate the render function signature.
-        if len(params) == 0 or len(params) > 2:
-            raise ForeachRenderError(
-                "Expected 1 or 2 parameters in foreach render function, got "
-                f"{[p.name for p in params]}. See "
-                "https://reflex.dev/docs/library/dynamic-rendering/foreach/"
-            )
-
-        if len(params) >= 1:
-            # Determine the arg var name based on the params accepted by render_fn.
-            props["arg_var_name"] = params[0].name
-
-        if len(params) == 2:
-            # Determine the index var name based on the params accepted by render_fn.
-            props["index_var_name"] = params[1].name
-        else:
-            # Otherwise, use a deterministic index, based on the render function bytecode.
-            code_hash = (
-                hash(self.render_fn.__code__)
-                .to_bytes(
-                    length=8,
-                    byteorder="big",
-                    signed=True,
-                )
-                .hex()
-            )
-            props["index_var_name"] = f"index_{code_hash}"
-
-        return IterTag(
-            iterable=self.iterable,
-            render_fn=self.render_fn,
-            children=self.children,
-            **props,
-        )
-
-    def render(self):
-        """Render the component.
-
-        Returns:
-            The dictionary for template of component.
-        """
-        tag = self._render()
-
-        return dict(
-            tag,
-            iterable_state=str(tag.iterable),
-            arg_name=tag.arg_var_name,
-            arg_index=tag.get_index_var_arg(),
-            iterable_type=tag.iterable._var_type.mro()[0].__name__,
-        )
-
 
-foreach = Foreach.create
+    return iterable.foreach(render_fn)

+ 1 - 1
reflex/components/radix/primitives/slider.py

@@ -188,7 +188,7 @@ class Slider(ComponentNamespace):
         else:
             children = [
                 track,
-                #     Foreach.create(props.get("value"), lambda e: SliderThumb.create()),  # foreach doesn't render Thumbs properly # noqa: ERA001
+                #     foreach(props.get("value"), lambda e: SliderThumb.create()),  # foreach doesn't render Thumbs properly # noqa: ERA001
             ]
 
         return SliderRoot.create(*children, **props)

+ 2 - 2
reflex/components/radix/themes/layout/list.py

@@ -5,7 +5,7 @@ from __future__ import annotations
 from typing import Any, Iterable, Literal, Union
 
 from reflex.components.component import Component, ComponentNamespace
-from reflex.components.core.foreach import Foreach
+from reflex.components.core.foreach import foreach
 from reflex.components.el.elements.typography import Li, Ol, Ul
 from reflex.components.lucide.icon import Icon
 from reflex.components.markdown.markdown import MarkdownComponentMap
@@ -70,7 +70,7 @@ class BaseList(Component, MarkdownComponentMap):
 
         if not children and items is not None:
             if isinstance(items, Var):
-                children = [Foreach.create(items, ListItem.create)]
+                children = [foreach(items, ListItem.create)]
             else:
                 children = [ListItem.create(item) for item in items]  # type: ignore
         props["direction"] = "column"

+ 0 - 1
reflex/components/tags/__init__.py

@@ -1,4 +1,3 @@
 """Representations for React tags."""
 
-from .iter_tag import IterTag
 from .tag import Tag

+ 0 - 141
reflex/components/tags/iter_tag.py

@@ -1,141 +0,0 @@
-"""Tag to loop through a list of components."""
-
-from __future__ import annotations
-
-import dataclasses
-import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args
-
-from reflex.components.tags.tag import Tag
-from reflex.utils import types
-from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
-
-if TYPE_CHECKING:
-    from reflex.components.component import Component
-
-
-@dataclasses.dataclass()
-class IterTag(Tag):
-    """An iterator tag."""
-
-    # The var to iterate over.
-    iterable: Var[Iterable] = dataclasses.field(
-        default_factory=lambda: LiteralArrayVar.create([])
-    )
-
-    # The component render function for each item in the iterable.
-    render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x)
-
-    # The name of the arg var.
-    arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
-
-    # The name of the index var.
-    index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
-
-    def get_iterable_var_type(self) -> types.GenericType:
-        """Get the type of the iterable var.
-
-        Returns:
-            The type of the iterable var.
-        """
-        iterable = self.iterable
-        try:
-            if iterable._var_type.mro()[0] is dict:
-                # Arg is a tuple of (key, value).
-                return Tuple[get_args(iterable._var_type)]
-            elif iterable._var_type.mro()[0] is tuple:
-                # Arg is a union of any possible values in the tuple.
-                return Union[get_args(iterable._var_type)]
-            else:
-                return get_args(iterable._var_type)[0]
-        except Exception:
-            return Any
-
-    def get_index_var(self) -> Var:
-        """Get the index var for the tag (with curly braces).
-
-        This is used to reference the index var within the tag.
-
-        Returns:
-            The index var.
-        """
-        return Var(
-            _js_expr=self.index_var_name,
-            _var_type=int,
-        ).guess_type()
-
-    def get_arg_var(self) -> Var:
-        """Get the arg var for the tag (with curly braces).
-
-        This is used to reference the arg var within the tag.
-
-        Returns:
-            The arg var.
-        """
-        return Var(
-            _js_expr=self.arg_var_name,
-            _var_type=self.get_iterable_var_type(),
-        ).guess_type()
-
-    def get_index_var_arg(self) -> Var:
-        """Get the index var for the tag (without curly braces).
-
-        This is used to render the index var in the .map() function.
-
-        Returns:
-            The index var.
-        """
-        return Var(
-            _js_expr=self.index_var_name,
-            _var_type=int,
-        ).guess_type()
-
-    def get_arg_var_arg(self) -> Var:
-        """Get the arg var for the tag (without curly braces).
-
-        This is used to render the arg var in the .map() function.
-
-        Returns:
-            The arg var.
-        """
-        return Var(
-            _js_expr=self.arg_var_name,
-            _var_type=self.get_iterable_var_type(),
-        ).guess_type()
-
-    def render_component(self) -> Component:
-        """Render the component.
-
-        Raises:
-            ValueError: If the render function takes more than 2 arguments.
-
-        Returns:
-            The rendered component.
-        """
-        # Import here to avoid circular imports.
-        from reflex.components.base.fragment import Fragment
-        from reflex.components.core.foreach import Foreach
-
-        # Get the render function arguments.
-        args = inspect.getfullargspec(self.render_fn).args
-        arg = self.get_arg_var()
-        index = self.get_index_var()
-
-        if len(args) == 1:
-            # If the render function doesn't take the index as an argument.
-            component = self.render_fn(arg)
-        else:
-            # If the render function takes the index as an argument.
-            if len(args) != 2:
-                raise ValueError("The render function must take 2 arguments.")
-            component = self.render_fn(arg, index)
-
-        # Nested foreach components or cond must be wrapped in fragments.
-        if isinstance(component, (Foreach, Var)):
-            component = Fragment.create(component)
-
-        # Set the component key.
-        if component.key is None:
-            component.key = index
-
-        return component

+ 11 - 0
reflex/utils/types.py

@@ -890,12 +890,23 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
     Returns:
         Whether the type hint is a subclass of the other type hint.
     """
+    if isinstance(possible_subclass, Sequence) and isinstance(
+        possible_superclass, Sequence
+    ):
+        return all(
+            typehint_issubclass(subclass, superclass)
+            for subclass, superclass in zip(possible_subclass, possible_superclass)
+        )
     if possible_subclass is possible_superclass:
         return True
     if possible_superclass is Any:
         return True
     if possible_subclass is Any:
         return False
+    if isinstance(
+        possible_subclass, (TypeVar, typing_extensions.TypeVar)
+    ) or isinstance(possible_superclass, (TypeVar, typing_extensions.TypeVar)):
+        return True
 
     provided_type_origin = get_origin(possible_subclass)
     accepted_type_origin = get_origin(possible_superclass)

+ 41 - 0
reflex/vars/base.py

@@ -151,6 +151,28 @@ def unwrap_reflex_callalbe(
     return args
 
 
+_VAR_CACHING = True
+
+
+def get_var_caching() -> bool:
+    """Get the var caching status.
+
+    Returns:
+        The var caching status.
+    """
+    return _VAR_CACHING
+
+
+def set_var_caching(value: bool):
+    """Set the var caching status.
+
+    Args:
+        value: The value to set the var caching status to.
+    """
+    global _VAR_CACHING
+    _VAR_CACHING = value
+
+
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
@@ -1186,6 +1208,25 @@ class Var(Generic[VAR_TYPE]):
         """
         return self
 
+    def __getattribute__(self, name: str) -> Any:
+        """Get an attribute of the var.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The attribute.
+        """
+        if not _VAR_CACHING:
+            try:
+                self_dict = object.__getattribute__(self, "__dict__")
+                for key in self_dict:
+                    if key.startswith("_cached_"):
+                        del self_dict[key]
+            except Exception:
+                pass
+        return super().__getattribute__(name)
+
     def __getattr__(self, name: str):
         """Get an attribute of the var.
 

+ 55 - 9
reflex/vars/sequence.py

@@ -741,7 +741,8 @@ if TYPE_CHECKING:
 def map_array_operation(
     array: Var[Sequence[INNER_ARRAY_VAR]],
     function: Var[
-        ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]
+        ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR]
+        | ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]
         | ReflexCallable[[], ANOTHER_ARRAY_VAR]
     ],
 ) -> CustomVarOperationReturn[Sequence[ANOTHER_ARRAY_VAR]]:
@@ -973,7 +974,8 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
 
     def foreach(
         self: ArrayVar[Sequence[INNER_ARRAY_VAR]],
-        fn: Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR]
+        fn: Callable[[Var[INNER_ARRAY_VAR], NumberVar[int]], ANOTHER_ARRAY_VAR]
+        | Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR]
         | Callable[[], ANOTHER_ARRAY_VAR],
     ) -> ArrayVar[Sequence[ANOTHER_ARRAY_VAR]]:
         """Apply a function to each element of the array.
@@ -987,21 +989,36 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
         Raises:
             VarTypeError: If the function takes more than one argument.
         """
+        from reflex.state import ComponentState
+
         from .function import ArgsFunctionOperation
 
         if not callable(fn):
             raise_unsupported_operand_types("foreach", (type(self), type(fn)))
         # get the number of arguments of the function
         num_args = len(inspect.signature(fn).parameters)
-        if num_args > 1:
+        if num_args > 2:
             raise VarTypeError(
-                "The function passed to foreach should take at most one argument."
+                "The function passed to foreach should take at most two arguments."
+            )
+
+        if (
+            hasattr(fn, "__qualname__")
+            and fn.__qualname__ == ComponentState.create.__qualname__
+        ):
+            raise TypeError(
+                "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet."
             )
 
         if num_args == 0:
-            return_value = fn()  # type: ignore
+            fn_result = fn()  # pyright: ignore [reportCallIssue]
+            return_value = Var.create(fn_result)
             simple_function_var: FunctionVar[ReflexCallable[[], ANOTHER_ARRAY_VAR]] = (
-                ArgsFunctionOperation.create((), return_value)
+                ArgsFunctionOperation.create(
+                    (),
+                    return_value,
+                    _var_type=ReflexCallable[[], return_value._var_type],
+                )
             )
             return map_array_operation(self, simple_function_var).guess_type()
 
@@ -1021,11 +1038,40 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
             ).guess_type(),
         )
 
+        if num_args == 1:
+            fn_result = fn(first_arg)  # pyright: ignore [reportCallIssue]
+
+            return_value = Var.create(fn_result)
+
+            function_var = cast(
+                Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]],
+                ArgsFunctionOperation.create(
+                    (arg_name,),
+                    return_value,
+                    _var_type=ReflexCallable[[first_arg_type], return_value._var_type],
+                ),
+            )
+
+            return map_array_operation.call(self, function_var).guess_type()
+
+        second_arg = cast(
+            NumberVar[int],
+            Var(
+                _js_expr=get_unique_variable_name(),
+                _var_type=int,
+            ).guess_type(),
+        )
+
+        fn_result = fn(first_arg, second_arg)  # pyright: ignore [reportCallIssue]
+
+        return_value = Var.create(fn_result)
+
         function_var = cast(
-            Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]],
+            Var[ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR]],
             ArgsFunctionOperation.create(
-                (arg_name,),
-                Var.create(fn(first_arg)),  # type: ignore
+                (arg_name, second_arg._js_expr),
+                return_value,
+                _var_type=ReflexCallable[[first_arg_type, int], return_value._var_type],
             ),
         )
 

+ 22 - 135
tests/units/components/core/test_foreach.py

@@ -6,16 +6,11 @@ import pytest
 from reflex import el
 from reflex.base import Base
 from reflex.components.component import Component
-from reflex.components.core.foreach import (
-    Foreach,
-    ForeachRenderError,
-    ForeachVarError,
-    foreach,
-)
+from reflex.components.core.foreach import ForeachVarError, foreach
 from reflex.components.radix.themes.layout.box import box
 from reflex.components.radix.themes.typography.text import text
 from reflex.state import BaseState, ComponentState
-from reflex.vars.base import Var
+from reflex.utils.exceptions import VarTypeError
 from reflex.vars.number import NumberVar
 from reflex.vars.sequence import ArrayVar
 
@@ -141,143 +136,35 @@ def display_color_index_tuple(color):
 seen_index_vars = set()
 
 
-@pytest.mark.parametrize(
-    "state_var, render_fn, render_dict",
-    [
-        (
-            ForEachState.colors_list,
-            display_color,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.colors_list",
-                "iterable_type": "list",
-            },
-        ),
-        (
-            ForEachState.colors_dict_list,
-            display_color_name,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list",
-                "iterable_type": "list",
-            },
-        ),
-        (
-            ForEachState.colors_nested_dict_list,
-            display_shade,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list",
-                "iterable_type": "list",
-            },
-        ),
-        (
-            ForEachState.primary_color,
-            display_primary_colors,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.primary_color",
-                "iterable_type": "dict",
-            },
-        ),
-        (
-            ForEachState.color_with_shades,
-            display_color_with_shades,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.color_with_shades",
-                "iterable_type": "dict",
-            },
-        ),
-        (
-            ForEachState.nested_colors_with_shades,
-            display_nested_color_with_shades,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades",
-                "iterable_type": "dict",
-            },
-        ),
-        (
-            ForEachState.nested_colors_with_shades,
-            display_nested_color_with_shades_v2,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades",
-                "iterable_type": "dict",
-            },
-        ),
-        (
-            ForEachState.color_tuple,
-            display_color_tuple,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.color_tuple",
-                "iterable_type": "tuple",
-            },
-        ),
-        (
-            ForEachState.colors_set,
-            display_colors_set,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.colors_set",
-                "iterable_type": "set",
-            },
-        ),
-        (
-            ForEachState.nested_colors_list,
-            lambda el, i: display_nested_list_element(el, i),
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list",
-                "iterable_type": "list",
-            },
-        ),
-        (
-            ForEachState.color_index_tuple,
-            display_color_index_tuple,
-            {
-                "iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple",
-                "iterable_type": "tuple",
-            },
-        ),
-    ],
-)
-def test_foreach_render(state_var, render_fn, render_dict):
-    """Test that the foreach component renders without error.
-
-    Args:
-        state_var: the state var.
-        render_fn: The render callable
-        render_dict: return dict on calling `component.render`
-    """
-    component = Foreach.create(state_var, render_fn)
-
-    rend = component.render()
-    assert rend["iterable_state"] == render_dict["iterable_state"]
-    assert rend["iterable_type"] == render_dict["iterable_type"]
-
-    # Make sure the index vars are unique.
-    arg_index = rend["arg_index"]
-    assert isinstance(arg_index, Var)
-    assert arg_index._js_expr not in seen_index_vars
-    assert arg_index._var_type is int
-    seen_index_vars.add(arg_index._js_expr)
-
-
 def test_foreach_bad_annotations():
     """Test that the foreach component raises a ForeachVarError if the iterable is of type Any."""
     with pytest.raises(ForeachVarError):
-        Foreach.create(
+        foreach(
             ForEachState.bad_annotation_list,
-            lambda sublist: Foreach.create(sublist, lambda color: text(color)),
+            lambda sublist: foreach(sublist, lambda color: text(color)),
         )
 
 
 def test_foreach_no_param_in_signature():
-    """Test that the foreach component raises a ForeachRenderError if no parameters are passed."""
-    with pytest.raises(ForeachRenderError):
-        Foreach.create(
-            ForEachState.colors_list,
-            lambda: text("color"),
-        )
+    """Test that the foreach component DOES NOT raise an error if no parameters are passed."""
+    foreach(
+        ForEachState.colors_list,
+        lambda: text("color"),
+    )
+
+
+def test_foreach_with_index():
+    """Test that the foreach component works with an index."""
+    foreach(
+        ForEachState.colors_list,
+        lambda color, index: text(color, index),
+    )
 
 
 def test_foreach_too_many_params_in_signature():
     """Test that the foreach component raises a ForeachRenderError if too many parameters are passed."""
-    with pytest.raises(ForeachRenderError):
-        Foreach.create(
+    with pytest.raises(VarTypeError):
+        foreach(
             ForEachState.colors_list,
             lambda color, index, extra: text(color),
         )
@@ -292,13 +179,13 @@ def test_foreach_component_styles():
         )
     )
     component._add_style_recursive({box: {"color": "red"}})
-    assert 'css={({ ["color"] : "red" })}' in str(component)
+    assert '{ ["css"] : ({ ["color"] : "red" }) }' in str(component)
 
 
 def test_foreach_component_state():
     """Test that using a component state to render in the foreach raises an error."""
     with pytest.raises(TypeError):
-        Foreach.create(
+        foreach(
             ForEachState.colors_list,
             ComponentStateTest.create,
         )
@@ -306,7 +193,7 @@ def test_foreach_component_state():
 
 def test_foreach_default_factory():
     """Test that the default factory is called."""
-    _ = Foreach.create(
+    _ = foreach(
         ForEachState.default_factory_list,
         lambda tag: text(tag.name),
     )

+ 2 - 6
tests/units/components/test_component.py

@@ -1446,7 +1446,6 @@ def test_instantiate_all_components():
     untested_components = {
         "Card",
         "DebounceInput",
-        "Foreach",
         "FormControl",
         "Html",
         "Icon",
@@ -2147,14 +2146,11 @@ def test_add_style_foreach():
     page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i)))
     page._add_style_recursive(Style())
 
-    # Expect only a single child of the foreach on the python side
-    assert len(page.children[0].children) == 1
-
     # Expect the style to be added to the child of the foreach
-    assert 'css={({ ["color"] : "red" })}' in str(page.children[0].children[0])
+    assert '({ ["css"] : ({ ["color"] : "red" }) }),' in str(page.children[0])
 
     # Expect only one instance of this CSS dict in the rendered page
-    assert str(page).count('css={({ ["color"] : "red" })}') == 1
+    assert str(page).count('({ ["css"] : ({ ["color"] : "red" }) }),') == 1
 
 
 class TriggerState(rx.State):