Преглед изворни кода

relax foreach to handle optional (#4901)

* relax foreach to handle optional

* simplify get index
Khaleel Al-Adhami пре 2 месеци
родитељ
комит
3a6f7475e8

+ 8 - 7
reflex/components/core/cond.py

@@ -9,6 +9,7 @@ from reflex.components.component import BaseComponent, Component, MemoizationLea
 from reflex.components.tags import CondTag, Tag
 from reflex.components.tags import CondTag, Tag
 from reflex.constants import Dirs
 from reflex.constants import Dirs
 from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
 from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
+from reflex.utils import types
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import VarData
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.base import LiteralVar, Var
@@ -145,20 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
     if c2 is None:
     if c2 is None:
         raise ValueError("For conditional vars, the second argument must be set.")
         raise ValueError("For conditional vars, the second argument must be set.")
 
 
-    def create_var(cond_part: Any) -> Var[Any]:
-        return LiteralVar.create(cond_part)
-
     # convert the truth and false cond parts into vars so the _var_data can be obtained.
     # convert the truth and false cond parts into vars so the _var_data can be obtained.
-    c1 = create_var(c1)
-    c2 = create_var(c2)
+    c1_var = Var.create(c1)
+    c2_var = Var.create(c2)
+
+    if condition is c1_var:
+        c1_var = c1_var.to(types.value_inside_optional(c1_var._var_type))
 
 
     # Create the conditional var.
     # Create the conditional var.
     return ternary_operation(
     return ternary_operation(
         cond_var.bool()._replace(
         cond_var.bool()._replace(
             merge_var_data=VarData(imports=_IS_TRUE_IMPORT),
             merge_var_data=VarData(imports=_IS_TRUE_IMPORT),
         ),
         ),
-        c1,
-        c2,
+        c1_var,
+        c2_var,
     )
     )
 
 
 
 

+ 5 - 1
reflex/components/core/foreach.py

@@ -8,9 +8,11 @@ from typing import Any, Callable, Iterable
 
 
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.fragment import Fragment
 from reflex.components.component import Component
 from reflex.components.component import Component
+from reflex.components.core.cond import cond
 from reflex.components.tags import IterTag
 from reflex.components.tags import IterTag
 from reflex.constants import MemoizationMode
 from reflex.constants import MemoizationMode
 from reflex.state import ComponentState
 from reflex.state import ComponentState
+from reflex.utils import types
 from reflex.utils.exceptions import UntypedVarError
 from reflex.utils.exceptions import UntypedVarError
 from reflex.vars.base import LiteralVar, Var
 from reflex.vars.base import LiteralVar, Var
 
 
@@ -85,6 +87,9 @@ class Foreach(Component):
                 "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
                 "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
             )
             )
 
 
+        if types.is_optional(iterable._var_type):
+            iterable = cond(iterable, iterable, [])
+
         component = cls(
         component = cls(
             iterable=iterable,
             iterable=iterable,
             render_fn=render_fn,
             render_fn=render_fn,
@@ -164,7 +169,6 @@ class Foreach(Component):
             iterable_state=str(tag.iterable),
             iterable_state=str(tag.iterable),
             arg_name=tag.arg_var_name,
             arg_name=tag.arg_var_name,
             arg_index=tag.get_index_var_arg(),
             arg_index=tag.get_index_var_arg(),
-            iterable_type=tag.iterable._var_type.mro()[0].__name__,
         )
         )
 
 
 
 

+ 5 - 14
reflex/components/tags/iter_tag.py

@@ -4,10 +4,12 @@ from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
 import inspect
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Type, Union, get_args
+from typing import TYPE_CHECKING, Callable, Iterable
 
 
 from reflex.components.tags.tag import Tag
 from reflex.components.tags.tag import Tag
+from reflex.utils.types import GenericType
 from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
 from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
+from reflex.vars.sequence import _determine_value_of_array_index
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.components.component import Component
     from reflex.components.component import Component
@@ -31,24 +33,13 @@ class IterTag(Tag):
     # The name of the index var.
     # The name of the index var.
     index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
     index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
 
 
-    def get_iterable_var_type(self) -> Type:
+    def get_iterable_var_type(self) -> GenericType:
         """Get the type of the iterable var.
         """Get the type of the iterable var.
 
 
         Returns:
         Returns:
             The type of the iterable var.
             The type of the iterable var.
         """
         """
-        iterable = self.iterable
-        try:
-            if iterable._var_type.mro()[0] is dict:
-                # Arg is a tuple of (key, value).
-                return tuple[get_args(iterable._var_type)]  # pyright: ignore [reportReturnType]
-            elif iterable._var_type.mro()[0] is tuple:
-                # Arg is a union of any possible values in the tuple.
-                return Union[get_args(iterable._var_type)]  # pyright: ignore [reportReturnType]
-            else:
-                return get_args(iterable._var_type)[0]
-        except Exception:
-            return Any  # pyright: ignore [reportReturnType]
+        return _determine_value_of_array_index(self.iterable._var_type)
 
 
     def get_index_var(self) -> Var:
     def get_index_var(self) -> Var:
         """Get the index var for the tag (with curly braces).
         """Get the index var for the tag (with curly braces).

+ 47 - 17
reflex/vars/base.py

@@ -1598,7 +1598,14 @@ def var_operation(  # pyright: ignore [reportOverlappingOverload]
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[bool]],
+    func: Callable[P, CustomVarOperationReturn[None]],
+) -> Callable[P, NoneVar]: ...
+
+
+@overload
+def var_operation(  # pyright: ignore [reportOverlappingOverload]
+    func: Callable[P, CustomVarOperationReturn[bool]]
+    | Callable[P, CustomVarOperationReturn[bool | None]],
 ) -> Callable[P, BooleanVar]: ...
 ) -> Callable[P, BooleanVar]: ...
 
 
 
 
@@ -1607,13 +1614,15 @@ NUMBER_T = TypeVar("NUMBER_T", int, float, int | float)
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[NUMBER_T]],
+    func: Callable[P, CustomVarOperationReturn[NUMBER_T]]
+    | Callable[P, CustomVarOperationReturn[NUMBER_T | None]],
 ) -> Callable[P, NumberVar[NUMBER_T]]: ...
 ) -> Callable[P, NumberVar[NUMBER_T]]: ...
 
 
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[str]],
+    func: Callable[P, CustomVarOperationReturn[str]]
+    | Callable[P, CustomVarOperationReturn[str | None]],
 ) -> Callable[P, StringVar]: ...
 ) -> Callable[P, StringVar]: ...
 
 
 
 
@@ -1622,7 +1631,8 @@ LIST_T = TypeVar("LIST_T", bound=Sequence)
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[LIST_T]],
+    func: Callable[P, CustomVarOperationReturn[LIST_T]]
+    | Callable[P, CustomVarOperationReturn[LIST_T | None]],
 ) -> Callable[P, ArrayVar[LIST_T]]: ...
 ) -> Callable[P, ArrayVar[LIST_T]]: ...
 
 
 
 
@@ -1631,13 +1641,15 @@ OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]],
+    func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]]
+    | Callable[P, CustomVarOperationReturn[OBJECT_TYPE | None]],
 ) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
 ) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
 
 
 
 
 @overload
 @overload
 def var_operation(
 def var_operation(
-    func: Callable[P, CustomVarOperationReturn[T]],
+    func: Callable[P, CustomVarOperationReturn[T]]
+    | Callable[P, CustomVarOperationReturn[T | None]],
 ) -> Callable[P, Var[T]]: ...
 ) -> Callable[P, Var[T]]: ...
 
 
 
 
@@ -3278,53 +3290,71 @@ class Field(Generic[FIELD_TYPE]):
         """
         """
 
 
     @overload
     @overload
-    def __get__(self: Field[bool], instance: None, owner: Any) -> BooleanVar: ...
+    def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...
 
 
     @overload
     @overload
     def __get__(
     def __get__(
-        self: Field[int] | Field[float] | Field[int | float], instance: None, owner: Any
-    ) -> NumberVar: ...
+        self: Field[bool] | Field[bool | None], instance: None, owner: Any
+    ) -> BooleanVar: ...
 
 
     @overload
     @overload
-    def __get__(self: Field[str], instance: None, owner: Any) -> StringVar: ...
+    def __get__(
+        self: Field[int]
+        | Field[float]
+        | Field[int | float]
+        | Field[int | None]
+        | Field[float | None]
+        | Field[int | float | None],
+        instance: None,
+        owner: Any,
+    ) -> NumberVar: ...
 
 
     @overload
     @overload
-    def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...
+    def __get__(
+        self: Field[str] | Field[str | None], instance: None, owner: Any
+    ) -> StringVar: ...
 
 
     @overload
     @overload
     def __get__(
     def __get__(
-        self: Field[list[V]] | Field[set[V]],
+        self: Field[list[V]]
+        | Field[set[V]]
+        | Field[list[V] | None]
+        | Field[set[V] | None],
         instance: None,
         instance: None,
         owner: Any,
         owner: Any,
     ) -> ArrayVar[Sequence[V]]: ...
     ) -> ArrayVar[Sequence[V]]: ...
 
 
     @overload
     @overload
     def __get__(
     def __get__(
-        self: Field[SEQUENCE_TYPE],
+        self: Field[SEQUENCE_TYPE] | Field[SEQUENCE_TYPE | None],
         instance: None,
         instance: None,
         owner: Any,
         owner: Any,
     ) -> ArrayVar[SEQUENCE_TYPE]: ...
     ) -> ArrayVar[SEQUENCE_TYPE]: ...
 
 
     @overload
     @overload
     def __get__(
     def __get__(
-        self: Field[MAPPING_TYPE], instance: None, owner: Any
+        self: Field[MAPPING_TYPE] | Field[MAPPING_TYPE | None],
+        instance: None,
+        owner: Any,
     ) -> ObjectVar[MAPPING_TYPE]: ...
     ) -> ObjectVar[MAPPING_TYPE]: ...
 
 
     @overload
     @overload
     def __get__(
     def __get__(
-        self: Field[BASE_TYPE], instance: None, owner: Any
+        self: Field[BASE_TYPE] | Field[BASE_TYPE | None], instance: None, owner: Any
     ) -> ObjectVar[BASE_TYPE]: ...
     ) -> ObjectVar[BASE_TYPE]: ...
 
 
     @overload
     @overload
     def __get__(
     def __get__(
-        self: Field[SQLA_TYPE], instance: None, owner: Any
+        self: Field[SQLA_TYPE] | Field[SQLA_TYPE | None], instance: None, owner: Any
     ) -> ObjectVar[SQLA_TYPE]: ...
     ) -> ObjectVar[SQLA_TYPE]: ...
 
 
     if TYPE_CHECKING:
     if TYPE_CHECKING:
 
 
         @overload
         @overload
         def __get__(
         def __get__(
-            self: Field[DATACLASS_TYPE], instance: None, owner: Any
+            self: Field[DATACLASS_TYPE] | Field[DATACLASS_TYPE | None],
+            instance: None,
+            owner: Any,
         ) -> ObjectVar[DATACLASS_TYPE]: ...
         ) -> ObjectVar[DATACLASS_TYPE]: ...
 
 
     @overload
     @overload

+ 21 - 6
reflex/vars/object.py

@@ -441,9 +441,14 @@ def object_keys_operation(value: ObjectVar):
     Returns:
     Returns:
         The keys of the object.
         The keys of the object.
     """
     """
+    if not types.is_optional(value._var_type):
+        return var_operation_return(
+            js_expression=f"Object.keys({value})",
+            var_type=list[str],
+        )
     return var_operation_return(
     return var_operation_return(
-        js_expression=f"Object.keys({value})",
-        var_type=list[str],
+        js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.keys(value))({value})",
+        var_type=(list[str] | None),
     )
     )
 
 
 
 
@@ -457,9 +462,14 @@ def object_values_operation(value: ObjectVar):
     Returns:
     Returns:
         The values of the object.
         The values of the object.
     """
     """
+    if not types.is_optional(value._var_type):
+        return var_operation_return(
+            js_expression=f"Object.values({value})",
+            var_type=list[value._value_type()],
+        )
     return var_operation_return(
     return var_operation_return(
-        js_expression=f"Object.values({value})",
-        var_type=list[value._value_type()],
+        js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.values(value))({value})",
+        var_type=(list[value._value_type()] | None),
     )
     )
 
 
 
 
@@ -473,9 +483,14 @@ def object_entries_operation(value: ObjectVar):
     Returns:
     Returns:
         The entries of the object.
         The entries of the object.
     """
     """
+    if not types.is_optional(value._var_type):
+        return var_operation_return(
+            js_expression=f"Object.entries({value})",
+            var_type=list[tuple[str, value._value_type()]],
+        )
     return var_operation_return(
     return var_operation_return(
-        js_expression=f"Object.entries({value})",
-        var_type=list[tuple[str, value._value_type()]],
+        js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.entries(value))({value})",
+        var_type=(list[tuple[str, value._value_type()]] | None),
     )
     )
 
 
 
 

+ 20 - 0
tests/integration/test_var_operations.py

@@ -33,6 +33,10 @@ def VarOperations():
         list2: rx.Field[list] = rx.field([3, 4])
         list2: rx.Field[list] = rx.field([3, 4])
         list3: rx.Field[list] = rx.field(["first", "second", "third"])
         list3: rx.Field[list] = rx.field(["first", "second", "third"])
         list4: rx.Field[list] = rx.field([Object(name="obj_1"), Object(name="obj_2")])
         list4: rx.Field[list] = rx.field([Object(name="obj_1"), Object(name="obj_2")])
+        optional_list: rx.Field[list | None] = rx.field(None)
+        optional_dict: rx.Field[dict[str, str] | None] = rx.field(None)
+        optional_list_value: rx.Field[list[str] | None] = rx.field(["red", "yellow"])
+        optional_dict_value: rx.Field[dict[str, str] | None] = rx.field({"name": "red"})
         str_var1: rx.Field[str] = rx.field("first")
         str_var1: rx.Field[str] = rx.field("first")
         str_var2: rx.Field[str] = rx.field("second")
         str_var2: rx.Field[str] = rx.field("second")
         str_var3: rx.Field[str] = rx.field("ThIrD")
         str_var3: rx.Field[str] = rx.field("ThIrD")
@@ -645,6 +649,22 @@ def VarOperations():
                 ),
                 ),
                 id="typed_dict_in_foreach",
                 id="typed_dict_in_foreach",
             ),
             ),
+            rx.box(
+                rx.foreach(VarOperationState.optional_list, rx.text.span),
+                id="optional_list",
+            ),
+            rx.box(
+                rx.foreach(VarOperationState.optional_dict, rx.text.span),
+                id="optional_dict",
+            ),
+            rx.box(
+                rx.foreach(VarOperationState.optional_list_value, rx.text.span),
+                id="optional_list_value",
+            ),
+            rx.box(
+                rx.foreach(VarOperationState.optional_dict_value, rx.text.span),
+                id="optional_dict_value",
+            ),
         )
         )
 
 
 
 

+ 29 - 12
tests/units/components/core/test_foreach.py

@@ -1,6 +1,7 @@
 import pydantic.v1
 import pydantic.v1
 import pytest
 import pytest
 
 
+import reflex as rx
 from reflex import el
 from reflex import el
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.component import Component
 from reflex.components.component import Component
@@ -54,6 +55,11 @@ class ForEachState(BaseState):
 
 
     default_factory_list: list[ForEachTag] = pydantic.v1.Field(default_factory=list)
     default_factory_list: list[ForEachTag] = pydantic.v1.Field(default_factory=list)
 
 
+    optional_list: rx.Field[list[str] | None] = rx.field(None)
+    optional_list_value: rx.Field[list[str] | None] = rx.field(["red", "yellow"])
+    optional_dict: rx.Field[dict[str, str] | None] = rx.field(None)
+    optional_dict_value: rx.Field[dict[str, str] | None] = rx.field({"name": "red"})
+
 
 
 class ComponentStateTest(ComponentState):
 class ComponentStateTest(ComponentState):
     """A test component state."""
     """A test component state."""
@@ -145,7 +151,6 @@ seen_index_vars = set()
             display_color,
             display_color,
             {
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_list",
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_list",
-                "iterable_type": "list",
             },
             },
         ),
         ),
         (
         (
@@ -153,7 +158,6 @@ seen_index_vars = set()
             display_color_name,
             display_color_name,
             {
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list",
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list",
-                "iterable_type": "list",
             },
             },
         ),
         ),
         (
         (
@@ -161,7 +165,6 @@ seen_index_vars = set()
             display_shade,
             display_shade,
             {
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list",
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list",
-                "iterable_type": "list",
             },
             },
         ),
         ),
         (
         (
@@ -169,7 +172,6 @@ seen_index_vars = set()
             display_primary_colors,
             display_primary_colors,
             {
             {
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.primary_color)",
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.primary_color)",
-                "iterable_type": "list",
             },
             },
         ),
         ),
         (
         (
@@ -177,7 +179,6 @@ seen_index_vars = set()
             display_color_with_shades,
             display_color_with_shades,
             {
             {
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.color_with_shades)",
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.color_with_shades)",
-                "iterable_type": "list",
             },
             },
         ),
         ),
         (
         (
@@ -185,7 +186,6 @@ seen_index_vars = set()
             display_nested_color_with_shades,
             display_nested_color_with_shades,
             {
             {
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
-                "iterable_type": "list",
             },
             },
         ),
         ),
         (
         (
@@ -193,7 +193,6 @@ seen_index_vars = set()
             display_nested_color_with_shades_v2,
             display_nested_color_with_shades_v2,
             {
             {
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
                 "iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
-                "iterable_type": "list",
             },
             },
         ),
         ),
         (
         (
@@ -201,7 +200,6 @@ seen_index_vars = set()
             display_color_tuple,
             display_color_tuple,
             {
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.color_tuple",
                 "iterable_state": f"{ForEachState.get_full_name()}.color_tuple",
-                "iterable_type": "tuple",
             },
             },
         ),
         ),
         (
         (
@@ -209,7 +207,6 @@ seen_index_vars = set()
             display_colors_set,
             display_colors_set,
             {
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_set",
                 "iterable_state": f"{ForEachState.get_full_name()}.colors_set",
-                "iterable_type": "set",
             },
             },
         ),
         ),
         (
         (
@@ -217,7 +214,6 @@ seen_index_vars = set()
             lambda el, i: display_nested_list_element(el, i),
             lambda el, i: display_nested_list_element(el, i),
             {
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list",
                 "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list",
-                "iterable_type": "list",
             },
             },
         ),
         ),
         (
         (
@@ -225,7 +221,6 @@ seen_index_vars = set()
             display_color_index_tuple,
             display_color_index_tuple,
             {
             {
                 "iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple",
                 "iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple",
-                "iterable_type": "tuple",
             },
             },
         ),
         ),
     ],
     ],
@@ -242,7 +237,6 @@ def test_foreach_render(state_var, render_fn, render_dict):
 
 
     rend = component.render()
     rend = component.render()
     assert rend["iterable_state"] == render_dict["iterable_state"]
     assert rend["iterable_state"] == render_dict["iterable_state"]
-    assert rend["iterable_type"] == render_dict["iterable_type"]
 
 
     # Make sure the index vars are unique.
     # Make sure the index vars are unique.
     arg_index = rend["arg_index"]
     arg_index = rend["arg_index"]
@@ -306,3 +300,26 @@ def test_foreach_default_factory():
         ForEachState.default_factory_list,
         ForEachState.default_factory_list,
         lambda tag: text(tag.name),
         lambda tag: text(tag.name),
     )
     )
+
+
+def test_optional_list():
+    """Test that the foreach component works with optional lists."""
+    Foreach.create(
+        ForEachState.optional_list,
+        lambda color: text(color),
+    )
+
+    Foreach.create(
+        ForEachState.optional_list_value,
+        lambda color: text(color),
+    )
+
+    Foreach.create(
+        ForEachState.optional_dict,
+        lambda color: text(color[0], color[1]),
+    )
+
+    Foreach.create(
+        ForEachState.optional_dict_value,
+        lambda color: text(color[0], color[1]),
+    )

+ 3 - 3
tests/units/vars/test_object.py

@@ -136,7 +136,7 @@ def test_typing() -> None:
     var = ObjectState.base
     var = ObjectState.base
     _ = assert_type(var, ObjectVar[Base])
     _ = assert_type(var, ObjectVar[Base])
     optional_var = ObjectState.base_optional
     optional_var = ObjectState.base_optional
-    _ = assert_type(optional_var, ObjectVar[Base | None])
+    _ = assert_type(optional_var, ObjectVar[Base])
     list_var = ObjectState.base_list
     list_var = ObjectState.base_list
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     list_var_0 = list_var[0]
     list_var_0 = list_var[0]
@@ -146,7 +146,7 @@ def test_typing() -> None:
     var = ObjectState.sqlamodel
     var = ObjectState.sqlamodel
     _ = assert_type(var, ObjectVar[SqlaModel])
     _ = assert_type(var, ObjectVar[SqlaModel])
     optional_var = ObjectState.sqlamodel_optional
     optional_var = ObjectState.sqlamodel_optional
-    _ = assert_type(optional_var, ObjectVar[SqlaModel | None])
+    _ = assert_type(optional_var, ObjectVar[SqlaModel])
     list_var = ObjectState.base_list
     list_var = ObjectState.base_list
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     list_var_0 = list_var[0]
     list_var_0 = list_var[0]
@@ -156,7 +156,7 @@ def test_typing() -> None:
     var = ObjectState.dataclass
     var = ObjectState.dataclass
     _ = assert_type(var, ObjectVar[Dataclass])
     _ = assert_type(var, ObjectVar[Dataclass])
     optional_var = ObjectState.dataclass_optional
     optional_var = ObjectState.dataclass_optional
-    _ = assert_type(optional_var, ObjectVar[Dataclass | None])
+    _ = assert_type(optional_var, ObjectVar[Dataclass])
     list_var = ObjectState.base_list
     list_var = ObjectState.base_list
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     _ = assert_type(list_var, ArrayVar[Sequence[Base]])
     list_var_0 = list_var[0]
     list_var_0 = list_var[0]