Jelajahi Sumber

relax foreach to handle optional (#4901)

* relax foreach to handle optional

* simplify get index
Khaleel Al-Adhami 2 bulan lalu
induk
melakukan
3a6f7475e8

+ 8 - 7
reflex/components/core/cond.py

@@ -9,6 +9,7 @@ from reflex.components.component import BaseComponent, Component, MemoizationLea
 from reflex.components.tags import CondTag, Tag
 from reflex.constants import Dirs
 from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
+from reflex.utils import types
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
@@ -145,20 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
     if c2 is None:
         raise ValueError("For conditional vars, the second argument must be set.")
 
-    def create_var(cond_part: Any) -> Var[Any]:
-        return LiteralVar.create(cond_part)
-
     # convert the truth and false cond parts into vars so the _var_data can be obtained.
-    c1 = create_var(c1)
-    c2 = create_var(c2)
+    c1_var = Var.create(c1)
+    c2_var = Var.create(c2)
+
+    if condition is c1_var:
+        c1_var = c1_var.to(types.value_inside_optional(c1_var._var_type))
 
     # Create the conditional var.
     return ternary_operation(
         cond_var.bool()._replace(
             merge_var_data=VarData(imports=_IS_TRUE_IMPORT),
         ),
-        c1,
-        c2,
+        c1_var,
+        c2_var,
     )
 
 

+ 5 - 1
reflex/components/core/foreach.py

@@ -8,9 +8,11 @@ from typing import Any, Callable, Iterable
 
 from reflex.components.base.fragment import Fragment
 from reflex.components.component import Component
+from reflex.components.core.cond import cond
 from reflex.components.tags import IterTag
 from reflex.constants import MemoizationMode
 from reflex.state import ComponentState
+from reflex.utils import types
 from reflex.utils.exceptions import UntypedVarError
 from reflex.vars.base import LiteralVar, Var
 
@@ -85,6 +87,9 @@ class Foreach(Component):
                 "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
             )
 
+        if types.is_optional(iterable._var_type):
+            iterable = cond(iterable, iterable, [])
+
         component = cls(
             iterable=iterable,
             render_fn=render_fn,
@@ -164,7 +169,6 @@ class Foreach(Component):
             iterable_state=str(tag.iterable),
             arg_name=tag.arg_var_name,
             arg_index=tag.get_index_var_arg(),
-            iterable_type=tag.iterable._var_type.mro()[0].__name__,
         )
 
 

+ 5 - 14
reflex/components/tags/iter_tag.py

@@ -4,10 +4,12 @@ from __future__ import annotations
 
 import dataclasses
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Type, Union, get_args
+from typing import TYPE_CHECKING, Callable, Iterable
 
 from reflex.components.tags.tag import Tag
+from reflex.utils.types import GenericType
 from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
+from reflex.vars.sequence import _determine_value_of_array_index
 
 if TYPE_CHECKING:
     from reflex.components.component import Component
@@ -31,24 +33,13 @@ class IterTag(Tag):
     # The name of the index var.
     index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
 
-    def get_iterable_var_type(self) -> Type:
+    def get_iterable_var_type(self) -> GenericType:
         """Get the type of the iterable var.
 
         Returns:
             The type of the iterable var.
         """
-        iterable = self.iterable
-        try:
-            if iterable._var_type.mro()[0] is dict:
-                # Arg is a tuple of (key, value).
-                return tuple[get_args(iterable._var_type)]  # pyright: ignore [reportReturnType]
-            elif iterable._var_type.mro()[0] is tuple:
-                # Arg is a union of any possible values in the tuple.
-                return Union[get_args(iterable._var_type)]  # pyright: ignore [reportReturnType]
-            else:
-                return get_args(iterable._var_type)[0]
-        except Exception:
-            return Any  # pyright: ignore [reportReturnType]
+        return _determine_value_of_array_index(self.iterable._var_type)
 
     def get_index_var(self) -> Var:
         """Get the index var for the tag (with curly braces).

+ 47 - 17
reflex/vars/base.py

@@ -1598,7 +1598,14 @@ def var_operation(  # pyright: ignore [reportOverlappingOverload]
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[bool]],
+    func: Callable[P, CustomVarOperationReturn[None]],
+) -> Callable[P, NoneVar]: ...
+
+
+@overload
+def var_operation(  # pyright: ignore [reportOverlappingOverload]
+    func: Callable[P, CustomVarOperationReturn[bool]]
+    | Callable[P, CustomVarOperationReturn[bool | None]],
 ) -> Callable[P, BooleanVar]: ...
 
 
@@ -1607,13 +1614,15 @@ NUMBER_T = TypeVar("NUMBER_T", int, float, int | float)
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[NUMBER_T]],
+    func: Callable[P, CustomVarOperationReturn[NUMBER_T]]
+    | Callable[P, CustomVarOperationReturn[NUMBER_T | None]],
 ) -> Callable[P, NumberVar[NUMBER_T]]: ...
 
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[str]],
+    func: Callable[P, CustomVarOperationReturn[str]]
+    | Callable[P, CustomVarOperationReturn[str | None]],
 ) -> Callable[P, StringVar]: ...
 
 
@@ -1622,7 +1631,8 @@ LIST_T = TypeVar("LIST_T", bound=Sequence)
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[LIST_T]],
+    func: Callable[P, CustomVarOperationReturn[LIST_T]]
+    | Callable[P, CustomVarOperationReturn[LIST_T | None]],
 ) -> Callable[P, ArrayVar[LIST_T]]: ...
 
 
@@ -1631,13 +1641,15 @@ OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]],
+    func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]]
+    | Callable[P, CustomVarOperationReturn[OBJECT_TYPE | None]],
 ) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
 
 
 @overload
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[T]],
+    func: Callable[P, CustomVarOperationReturn[T]]
+    | Callable[P, CustomVarOperationReturn[T | None]],
 ) -> Callable[P, Var[T]]: ...
 
 
@@ -3278,53 +3290,71 @@ class Field(Generic[FIELD_TYPE]):
         """
 
     @overload
-    def __get__(self: Field[bool], instance: None, owner: Any) -> BooleanVar: ...
+    def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...
 
     @overload
     def __get__(
-        self: Field[int] | Field[float] | Field[int | float], instance: None, owner: Any
-    ) -> NumberVar: ...
+        self: Field[bool] | Field[bool | None], instance: None, owner: Any
+    ) -> BooleanVar: ...
 
     @overload
-    def __get__(self: Field[str], instance: None, owner: Any) -> StringVar: ...
+    def __get__(
+        self: Field[int]
+        | Field[float]
+        | Field[int | float]
+        | Field[int | None]
+        | Field[float | None]
+        | Field[int | float | None],
+        instance: None,
+        owner: Any,
+    ) -> NumberVar: ...
 
     @overload
-    def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...
+    def __get__(
+        self: Field[str] | Field[str | None], instance: None, owner: Any
+    ) -> StringVar: ...
 
     @overload
     def __get__(
-        self: Field[list[V]] | Field[set[V]],
+        self: Field[list[V]]
+        | Field[set[V]]
+        | Field[list[V] | None]
+        | Field[set[V] | None],
         instance: None,
         owner: Any,
     ) -> ArrayVar[Sequence[V]]: ...
 
     @overload
     def __get__(
-        self: Field[SEQUENCE_TYPE],
+        self: Field[SEQUENCE_TYPE] | Field[SEQUENCE_TYPE | None],
         instance: None,
         owner: Any,
     ) -> ArrayVar[SEQUENCE_TYPE]: ...
 
     @overload
     def __get__(
-        self: Field[MAPPING_TYPE], instance: None, owner: Any
+        self: Field[MAPPING_TYPE] | Field[MAPPING_TYPE | None],
+        instance: None,
+        owner: Any,
     ) -> ObjectVar[MAPPING_TYPE]: ...
 
     @overload
     def __get__(
-        self: Field[BASE_TYPE], instance: None, owner: Any
+        self: Field[BASE_TYPE] | Field[BASE_TYPE | None], instance: None, owner: Any
     ) -> ObjectVar[BASE_TYPE]: ...
 
     @overload
     def __get__(
-        self: Field[SQLA_TYPE], instance: None, owner: Any
+        self: Field[SQLA_TYPE] | Field[SQLA_TYPE | None], instance: None, owner: Any
     ) -> ObjectVar[SQLA_TYPE]: ...
 
     if TYPE_CHECKING:
 
         @overload
         def __get__(
-            self: Field[DATACLASS_TYPE], instance: None, owner: Any
+            self: Field[DATACLASS_TYPE] | Field[DATACLASS_TYPE | None],
+            instance: None,
+            owner: Any,
         ) -> ObjectVar[DATACLASS_TYPE]: ...
 
     @overload

+ 21 - 6
reflex/vars/object.py

@@ -441,9 +441,14 @@ def object_keys_operation(value: ObjectVar):
     Returns:
         The keys of the object.
     """
+    if not types.is_optional(value._var_type):
+        return var_operation_return(
+            js_expression=f"Object.keys({value})",
+            var_type=list[str],
+        )
     return var_operation_return(
-        js_expression=f"Object.keys({value})",
-        var_type=list[str],
+        js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.keys(value))({value})",
+        var_type=(list[str] | None),
     )
 
 
@@ -457,9 +462,14 @@ def object_values_operation(value: ObjectVar):
     Returns:
         The values of the object.
     """
+    if not types.is_optional(value._var_type):
+        return var_operation_return(
+            js_expression=f"Object.values({value})",
+            var_type=list[value._value_type()],
+        )
     return var_operation_return(
-        js_expression=f"Object.values({value})",
-        var_type=list[value._value_type()],
+        js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.values(value))({value})",
+        var_type=(list[value._value_type()] | None),
     )
 
 
@@ -473,9 +483,14 @@ def object_entries_operation(value: ObjectVar):
     Returns:
         The entries of the object.
     """
+    if not types.is_optional(value._var_type):
+        return var_operation_return(
+            js_expression=f"Object.entries({value})",
+            var_type=list[tuple[str, value._value_type()]],
+        )
     return var_operation_return(
-        js_expression=f"Object.entries({value})",
-        var_type=list[tuple[str, value._value_type()]],
+        js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.entries(value))({value})",
+        var_type=(list[tuple[str, value._value_type()]] | None),
     )
 
 

+ 20 - 0
tests/integration/test_var_operations.py

@@ -33,6 +33,10 @@ def VarOperations():
         list2: rx.Field[list] = rx.field([3, 4])
         list3: rx.Field[list] = rx.field(["first", "second", "third"])
         list4: rx.Field[list] = rx.field([Object(name="obj_1"), Object(name="obj_2")])
+        optional_list: rx.Field[list | None] = rx.field(None)
+        optional_dict: rx.Field[dict[str, str] | None] = rx.field(None)
+        optional_list_value: rx.Field[list[str] | None] = rx.field(["red", "yellow"])
+        optional_dict_value: rx.Field[dict[str, str] | None] = rx.field({"name": "red"})
         str_var1: rx.Field[str] = rx.field("first")
         str_var2: rx.Field[str] = rx.field("second")
         str_var3: rx.Field[str] = rx.field("ThIrD")
@@ -645,6 +649,22 @@ def VarOperations():
                 ),
                 id="typed_dict_in_foreach",
             ),
+            rx.box(
+                rx.foreach(VarOperationState.optional_list, rx.text.span),
+                id="optional_list",
+            ),
+            rx.box(
+                rx.foreach(VarOperationState.optional_dict, rx.text.span),
+                id="optional_dict",
+            ),
+            rx.box(
+                rx.foreach(VarOperationState.optional_list_value, rx.text.span),
+                id="optional_list_value",
+            ),
+            rx.box(
+                rx.foreach(VarOperationState.optional_dict_value, rx.text.span),
+                id="optional_dict_value",
+            ),
         )
 
 

+ 29 - 12
tests/units/components/core/test_foreach.py

@@ -1,6 +1,7 @@
 import pydantic.v1
 import pytest
 
+import reflex as rx
 from reflex import el
 from reflex.base import Base
 from reflex.components.component import Component
@@ -54,6 +55,11 @@ class ForEachState(BaseState):
 
     default_factory_list: list[ForEachTag] = pydantic.v1.Field(default_factory=list)
 
+    optional_list: rx.Field[list[str] | None] = rx.field(None)
+    optional_list_value: rx.Field[list[str] | None] = rx.field(["red", "yellow"])
+    optional_dict: rx.Field[dict[str, str] | None] = rx.field(None)
+    optional_dict_value: rx.Field[dict[str, str] | None] = rx.field({"name": "red"})
+
 
 class ComponentStateTest(ComponentState):
     """A test component state."""
@@ -145,7 +151,6 @@ seen_index_vars = set()
             display_color,
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_list",
-                "iterable_type": "list",
             },
         ),
         (
@@ -153,7 +158,6 @@ seen_index_vars = set()
             display_color_name,
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list",
-                "iterable_type": "list",
             },
         ),
         (
@@ -161,7 +165,6 @@ seen_index_vars = set()
             display_shade,
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list",
-                "iterable_type": "list",
             },
         ),
         (
@@ -169,7 +172,6 @@ seen_index_vars = set()
             display_primary_colors,
             {
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.primary_color)",
-                "iterable_type": "list",
             },
         ),
         (
@@ -177,7 +179,6 @@ seen_index_vars = set()
             display_color_with_shades,
             {
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.color_with_shades)",
-                "iterable_type": "list",
             },
         ),
         (
@@ -185,7 +186,6 @@ seen_index_vars = set()
             display_nested_color_with_shades,
             {
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
-                "iterable_type": "list",
             },
         ),
         (
@@ -193,7 +193,6 @@ seen_index_vars = set()
             display_nested_color_with_shades_v2,
             {
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
-                "iterable_type": "list",
             },
         ),
         (
@@ -201,7 +200,6 @@ seen_index_vars = set()
             display_color_tuple,
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.color_tuple",
-                "iterable_type": "tuple",
             },
         ),
         (
@@ -209,7 +207,6 @@ seen_index_vars = set()
             display_colors_set,
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_set",
-                "iterable_type": "set",
             },
         ),
         (
@@ -217,7 +214,6 @@ seen_index_vars = set()
             lambda el, i: display_nested_list_element(el, i),
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list",
-                "iterable_type": "list",
             },
         ),
         (
@@ -225,7 +221,6 @@ seen_index_vars = set()
             display_color_index_tuple,
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple",
-                "iterable_type": "tuple",
             },
         ),
     ],
@@ -242,7 +237,6 @@ def test_foreach_render(state_var, render_fn, render_dict):
 
     rend = component.render()
     assert rend["iterable_state"] == render_dict["iterable_state"]
-    assert rend["iterable_type"] == render_dict["iterable_type"]
 
     # Make sure the index vars are unique.
     arg_index = rend["arg_index"]
@@ -306,3 +300,26 @@ def test_foreach_default_factory():
         ForEachState.default_factory_list,
         lambda tag: text(tag.name),
     )
+
+
+def test_optional_list():
+    """Test that the foreach component works with optional lists."""
+    Foreach.create(
+        ForEachState.optional_list,
+        lambda color: text(color),
+    )
+
+    Foreach.create(
+        ForEachState.optional_list_value,
+        lambda color: text(color),
+    )
+
+    Foreach.create(
+        ForEachState.optional_dict,
+        lambda color: text(color[0], color[1]),
+    )
+
+    Foreach.create(
+        ForEachState.optional_dict_value,
+        lambda color: text(color[0], color[1]),
+    )

+ 3 - 3
tests/units/vars/test_object.py

@@ -136,7 +136,7 @@ def test_typing() -> None:
     var = ObjectState.base
     _ = assert_type(var, ObjectVar[Base])
     optional_var = ObjectState.base_optional
-    _ = assert_type(optional_var, ObjectVar[Base | None])
+    _ = assert_type(optional_var, ObjectVar[Base])
     list_var = ObjectState.base_list
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     list_var_0 = list_var[0]
@@ -146,7 +146,7 @@ def test_typing() -> None:
     var = ObjectState.sqlamodel
     _ = assert_type(var, ObjectVar[SqlaModel])
     optional_var = ObjectState.sqlamodel_optional
-    _ = assert_type(optional_var, ObjectVar[SqlaModel | None])
+    _ = assert_type(optional_var, ObjectVar[SqlaModel])
     list_var = ObjectState.base_list
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     list_var_0 = list_var[0]
@@ -156,7 +156,7 @@ def test_typing() -> None:
     var = ObjectState.dataclass
     _ = assert_type(var, ObjectVar[Dataclass])
     optional_var = ObjectState.dataclass_optional
-    _ = assert_type(optional_var, ObjectVar[Dataclass | None])
+    _ = assert_type(optional_var, ObjectVar[Dataclass])
     list_var = ObjectState.base_list
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     list_var_0 = list_var[0]