Browse Source

add type hinting to existing types (#3729)

* add type hinting to existing types

* dang it darglint

* i cannot
Khaleel Al-Adhami 9 months ago
parent
commit
ad14f38329

+ 51 - 1
reflex/experimental/vars/base.py

@@ -10,9 +10,15 @@ from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
+    Dict,
+    Generic,
+    List,
     Optional,
+    Set,
+    Tuple,
     Type,
     TypeVar,
+    Union,
     overload,
 )
 
@@ -42,13 +48,15 @@ if TYPE_CHECKING:
     from .object import ObjectVar, ToObjectOperation
     from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
 
+VAR_TYPE = TypeVar("VAR_TYPE")
+
 
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
-class ImmutableVar(Var):
+class ImmutableVar(Var, Generic[VAR_TYPE]):
     """Base class for immutable vars."""
 
     # The name of the var.
@@ -405,6 +413,8 @@ class ImmutableVar(Var):
             return self.to(ArrayVar, var_type)
         if issubclass(fixed_type, str):
             return self.to(StringVar)
+        if issubclass(fixed_type, Base):
+            return self.to(ObjectVar, var_type)
         return self
 
 
@@ -531,3 +541,43 @@ def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P
         return wrapper
 
     return decorator
+
+
+def unionize(*args: Type) -> Type:
+    """Unionize the types.
+
+    Args:
+        args: The types to unionize.
+
+    Returns:
+        The unionized types.
+    """
+    if not args:
+        return Any
+    first, *rest = args
+    if not rest:
+        return first
+    return Union[first, unionize(*rest)]
+
+
+def figure_out_type(value: Any) -> Type:
+    """Figure out the type of the value.
+
+    Args:
+        value: The value to figure out the type of.
+
+    Returns:
+        The type of the value.
+    """
+    if isinstance(value, list):
+        return List[unionize(*(figure_out_type(v) for v in value))]
+    if isinstance(value, set):
+        return Set[unionize(*(figure_out_type(v) for v in value))]
+    if isinstance(value, tuple):
+        return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
+    if isinstance(value, dict):
+        return Dict[
+            unionize(*(figure_out_type(k) for k in value)),
+            unionize(*(figure_out_type(v) for v in value.values())),
+        ]
+    return type(value)

+ 1 - 1
reflex/experimental/vars/function.py

@@ -11,7 +11,7 @@ from reflex.experimental.vars.base import ImmutableVar, LiteralVar
 from reflex.vars import ImmutableVarData, Var, VarData
 
 
-class FunctionVar(ImmutableVar):
+class FunctionVar(ImmutableVar[Callable]):
     """Base class for immutable function vars."""
 
     def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:

+ 2 - 2
reflex/experimental/vars/number.py

@@ -15,7 +15,7 @@ from reflex.experimental.vars.base import (
 from reflex.vars import ImmutableVarData, Var, VarData
 
 
-class NumberVar(ImmutableVar):
+class NumberVar(ImmutableVar[Union[int, float]]):
     """Base class for immutable number vars."""
 
     def __add__(self, other: number_types | boolean_types) -> NumberAddOperation:
@@ -693,7 +693,7 @@ class NumberTruncOperation(UnaryNumberOperation):
         return f"Math.trunc({str(value)})"
 
 
-class BooleanVar(ImmutableVar):
+class BooleanVar(ImmutableVar[bool]):
     """Base class for immutable boolean vars."""
 
     def __and__(self, other: bool) -> BooleanAndOperation:

+ 201 - 24
reflex/experimental/vars/object.py

@@ -6,23 +6,69 @@ import dataclasses
 import sys
 import typing
 from functools import cached_property
-from typing import Any, Dict, List, Tuple, Type, Union
+from inspect import isclass
+from typing import (
+    Any,
+    Dict,
+    List,
+    NoReturn,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    get_args,
+    overload,
+)
+
+from typing_extensions import get_origin
 
-from reflex.experimental.vars.base import ImmutableVar, LiteralVar
-from reflex.experimental.vars.sequence import ArrayVar, unionize
+from reflex.experimental.vars.base import (
+    ImmutableVar,
+    LiteralVar,
+    figure_out_type,
+)
+from reflex.experimental.vars.number import NumberVar
+from reflex.experimental.vars.sequence import ArrayVar, StringVar
+from reflex.utils.exceptions import VarAttributeError
+from reflex.utils.types import GenericType, get_attribute_access_type
 from reflex.vars import ImmutableVarData, Var, VarData
 
+OBJECT_TYPE = TypeVar("OBJECT_TYPE")
+
+KEY_TYPE = TypeVar("KEY_TYPE")
+VALUE_TYPE = TypeVar("VALUE_TYPE")
+
+ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
 
-class ObjectVar(ImmutableVar):
+OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
+
+
+class ObjectVar(ImmutableVar[OBJECT_TYPE]):
     """Base class for immutable object vars."""
 
+    @overload
+    def _key_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> KEY_TYPE: ...
+
+    @overload
+    def _key_type(self) -> Type: ...
+
     def _key_type(self) -> Type:
         """Get the type of the keys of the object.
 
         Returns:
             The type of the keys of the object.
         """
-        return ImmutableVar
+        fixed_type = (
+            self._var_type if isclass(self._var_type) else get_origin(self._var_type)
+        )
+        args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
+        return args[0] if args else Any
+
+    @overload
+    def _value_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> VALUE_TYPE: ...
+
+    @overload
+    def _value_type(self) -> Type: ...
 
     def _value_type(self) -> Type:
         """Get the type of the values of the object.
@@ -30,9 +76,21 @@ class ObjectVar(ImmutableVar):
         Returns:
             The type of the values of the object.
         """
-        return ImmutableVar
+        fixed_type = (
+            self._var_type if isclass(self._var_type) else get_origin(self._var_type)
+        )
+        args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
+        return args[1] if args else Any
+
+    @overload
+    def keys(
+        self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
+    ) -> ArrayVar[List[KEY_TYPE]]: ...
+
+    @overload
+    def keys(self) -> ArrayVar: ...
 
-    def keys(self) -> ObjectKeysOperation:
+    def keys(self) -> ArrayVar:
         """Get the keys of the object.
 
         Returns:
@@ -40,7 +98,15 @@ class ObjectVar(ImmutableVar):
         """
         return ObjectKeysOperation(self)
 
-    def values(self) -> ObjectValuesOperation:
+    @overload
+    def values(
+        self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
+    ) -> ArrayVar[List[VALUE_TYPE]]: ...
+
+    @overload
+    def values(self) -> ArrayVar: ...
+
+    def values(self) -> ArrayVar:
         """Get the values of the object.
 
         Returns:
@@ -48,7 +114,15 @@ class ObjectVar(ImmutableVar):
         """
         return ObjectValuesOperation(self)
 
-    def entries(self) -> ObjectEntriesOperation:
+    @overload
+    def entries(
+        self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
+    ) -> ArrayVar[List[Tuple[KEY_TYPE, VALUE_TYPE]]]: ...
+
+    @overload
+    def entries(self) -> ArrayVar: ...
+
+    def entries(self) -> ArrayVar:
         """Get the entries of the object.
 
         Returns:
@@ -67,6 +141,53 @@ class ObjectVar(ImmutableVar):
         """
         return ObjectMergeOperation(self, other)
 
+    # NoReturn is used here to catch when key value is Any
+    @overload
+    def __getitem__(
+        self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
+        key: Var | Any,
+    ) -> ImmutableVar: ...
+
+    @overload
+    def __getitem__(
+        self: (
+            ObjectVar[Dict[KEY_TYPE, int]]
+            | ObjectVar[Dict[KEY_TYPE, float]]
+            | ObjectVar[Dict[KEY_TYPE, int | float]]
+        ),
+        key: Var | Any,
+    ) -> NumberVar: ...
+
+    @overload
+    def __getitem__(
+        self: ObjectVar[Dict[KEY_TYPE, str]],
+        key: Var | Any,
+    ) -> StringVar: ...
+
+    @overload
+    def __getitem__(
+        self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
+        key: Var | Any,
+    ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
+
+    @overload
+    def __getitem__(
+        self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
+        key: Var | Any,
+    ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
+
+    @overload
+    def __getitem__(
+        self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
+        key: Var | Any,
+    ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
+
+    @overload
+    def __getitem__(
+        self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
+        key: Var | Any,
+    ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
+
     def __getitem__(self, key: Var | Any) -> ImmutableVar:
         """Get an item from the object.
 
@@ -78,16 +199,78 @@ class ObjectVar(ImmutableVar):
         """
         return ObjectItemOperation(self, key).guess_type()
 
-    def __getattr__(self, name) -> ObjectItemOperation:
+    # NoReturn is used here to catch when key value is Any
+    @overload
+    def __getattr__(
+        self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
+        name: str,
+    ) -> ImmutableVar: ...
+
+    @overload
+    def __getattr__(
+        self: (
+            ObjectVar[Dict[KEY_TYPE, int]]
+            | ObjectVar[Dict[KEY_TYPE, float]]
+            | ObjectVar[Dict[KEY_TYPE, int | float]]
+        ),
+        name: str,
+    ) -> NumberVar: ...
+
+    @overload
+    def __getattr__(
+        self: ObjectVar[Dict[KEY_TYPE, str]],
+        name: str,
+    ) -> StringVar: ...
+
+    @overload
+    def __getattr__(
+        self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
+        name: str,
+    ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
+
+    @overload
+    def __getattr__(
+        self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
+        name: str,
+    ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
+
+    @overload
+    def __getattr__(
+        self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
+        name: str,
+    ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
+
+    @overload
+    def __getattr__(
+        self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
+        name: str,
+    ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
+
+    def __getattr__(self, name) -> ImmutableVar:
         """Get an attribute of the var.
 
         Args:
             name: The name of the attribute.
 
+        Raises:
+            VarAttributeError: The State var has no such attribute or may have been annotated wrongly.
+
         Returns:
             The attribute of the var.
         """
-        return ObjectItemOperation(self, name)
+        fixed_type = (
+            self._var_type if isclass(self._var_type) else get_origin(self._var_type)
+        )
+        if not issubclass(fixed_type, dict):
+            attribute_type = get_attribute_access_type(self._var_type, name)
+            if attribute_type is None:
+                raise VarAttributeError(
+                    f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated "
+                    f"wrongly."
+                )
+            return ObjectItemOperation(self, name, attribute_type).guess_type()
+        else:
+            return ObjectItemOperation(self, name).guess_type()
 
 
 @dataclasses.dataclass(
@@ -95,7 +278,7 @@ class ObjectVar(ImmutableVar):
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
-class LiteralObjectVar(LiteralVar, ObjectVar):
+class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]):
     """Base class for immutable literal object vars."""
 
     _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
@@ -103,9 +286,9 @@ class LiteralObjectVar(LiteralVar, ObjectVar):
     )
 
     def __init__(
-        self,
-        _var_value: dict[Var | Any, Var | Any],
-        _var_type: Type | None = None,
+        self: LiteralObjectVar[OBJECT_TYPE],
+        _var_value: OBJECT_TYPE,
+        _var_type: Type[OBJECT_TYPE] | None = None,
         _var_data: VarData | None = None,
     ):
         """Initialize the object var.
@@ -117,14 +300,7 @@ class LiteralObjectVar(LiteralVar, ObjectVar):
         """
         super(LiteralObjectVar, self).__init__(
             _var_name="",
-            _var_type=(
-                Dict[
-                    unionize(*map(type, _var_value.keys())),
-                    unionize(*map(type, _var_value.values())),
-                ]
-                if _var_type is None
-                else _var_type
-            ),
+            _var_type=(figure_out_type(_var_value) if _var_type is None else _var_type),
             _var_data=ImmutableVarData.merge(_var_data),
         )
         object.__setattr__(
@@ -489,6 +665,7 @@ class ObjectItemOperation(ImmutableVar):
         self,
         value: ObjectVar,
         key: Var | Any,
+        _var_type: GenericType | None = None,
         _var_data: VarData | None = None,
     ):
         """Initialize the object item operation.
@@ -500,7 +677,7 @@ class ObjectItemOperation(ImmutableVar):
         """
         super(ObjectItemOperation, self).__init__(
             _var_name="",
-            _var_type=value._value_type(),
+            _var_type=value._value_type() if _var_type is None else _var_type,
             _var_data=ImmutableVarData.merge(_var_data),
         )
         object.__setattr__(self, "value", value)

+ 134 - 36
reflex/experimental/vars/sequence.py

@@ -10,7 +10,18 @@ import re
 import sys
 import typing
 from functools import cached_property
-from typing import Any, List, Set, Tuple, Type, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Literal,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
 
 from typing_extensions import get_origin
 
@@ -19,6 +30,8 @@ from reflex.constants.base import REFLEX_VAR_OPENING_TAG
 from reflex.experimental.vars.base import (
     ImmutableVar,
     LiteralVar,
+    figure_out_type,
+    unionize,
 )
 from reflex.experimental.vars.number import (
     BooleanVar,
@@ -29,8 +42,11 @@ from reflex.experimental.vars.number import (
 from reflex.utils.types import GenericType
 from reflex.vars import ImmutableVarData, Var, VarData, _global_vars
 
+if TYPE_CHECKING:
+    from .object import ObjectVar
+
 
-class StringVar(ImmutableVar):
+class StringVar(ImmutableVar[str]):
     """Base class for immutable string vars."""
 
     def __add__(self, other: StringVar | str) -> ConcatVarOperation:
@@ -699,7 +715,17 @@ class ConcatVarOperation(StringVar):
         pass
 
 
-class ArrayVar(ImmutableVar):
+ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set])
+
+OTHER_TUPLE = TypeVar("OTHER_TUPLE")
+
+INNER_ARRAY_VAR = TypeVar("INNER_ARRAY_VAR")
+
+KEY_TYPE = TypeVar("KEY_TYPE")
+VALUE_TYPE = TypeVar("VALUE_TYPE")
+
+
+class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]):
     """Base class for immutable array vars."""
 
     from reflex.experimental.vars.sequence import StringVar
@@ -717,7 +743,7 @@ class ArrayVar(ImmutableVar):
 
         return ArrayJoinOperation(self, sep)
 
-    def reverse(self) -> ArrayReverseOperation:
+    def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]:
         """Reverse the array.
 
         Returns:
@@ -726,14 +752,98 @@ class ArrayVar(ImmutableVar):
         return ArrayReverseOperation(self)
 
     @overload
-    def __getitem__(self, i: slice) -> ArraySliceOperation: ...
+    def __getitem__(self, i: slice) -> ArrayVar[ARRAY_VAR_TYPE]: ...
+
+    @overload
+    def __getitem__(
+        self: (
+            ArrayVar[Tuple[int, OTHER_TUPLE]]
+            | ArrayVar[Tuple[float, OTHER_TUPLE]]
+            | ArrayVar[Tuple[int | float, OTHER_TUPLE]]
+        ),
+        i: Literal[0, -2],
+    ) -> NumberVar: ...
+
+    @overload
+    def __getitem__(
+        self: (
+            ArrayVar[Tuple[OTHER_TUPLE, int]]
+            | ArrayVar[Tuple[OTHER_TUPLE, float]]
+            | ArrayVar[Tuple[OTHER_TUPLE, int | float]]
+        ),
+        i: Literal[1, -1],
+    ) -> NumberVar: ...
+
+    @overload
+    def __getitem__(
+        self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2]
+    ) -> StringVar: ...
+
+    @overload
+    def __getitem__(
+        self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1]
+    ) -> StringVar: ...
+
+    @overload
+    def __getitem__(
+        self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2]
+    ) -> BooleanVar: ...
+
+    @overload
+    def __getitem__(
+        self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1]
+    ) -> BooleanVar: ...
+
+    @overload
+    def __getitem__(
+        self: (
+            ARRAY_VAR_OF_LIST_ELEMENT[int]
+            | ARRAY_VAR_OF_LIST_ELEMENT[float]
+            | ARRAY_VAR_OF_LIST_ELEMENT[int | float]
+        ),
+        i: int | NumberVar,
+    ) -> NumberVar: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar
+    ) -> StringVar: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
+    ) -> BooleanVar: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]],
+        i: int | NumberVar,
+    ) -> ArrayVar[List[INNER_ARRAY_VAR]]: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]],
+        i: int | NumberVar,
+    ) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[INNER_ARRAY_VAR, ...]],
+        i: int | NumberVar,
+    ) -> ArrayVar[Tuple[INNER_ARRAY_VAR, ...]]: ...
+
+    @overload
+    def __getitem__(
+        self: ARRAY_VAR_OF_LIST_ELEMENT[Dict[KEY_TYPE, VALUE_TYPE]],
+        i: int | NumberVar,
+    ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
 
     @overload
     def __getitem__(self, i: int | NumberVar) -> ImmutableVar: ...
 
     def __getitem__(
         self, i: slice | int | NumberVar
-    ) -> ArraySliceOperation | ImmutableVar:
+    ) -> ArrayVar[ARRAY_VAR_TYPE] | ImmutableVar:
         """Get a slice of the array.
 
         Args:
@@ -756,7 +866,7 @@ class ArrayVar(ImmutableVar):
 
     @overload
     @classmethod
-    def range(cls, stop: int | NumberVar, /) -> RangeOperation: ...
+    def range(cls, stop: int | NumberVar, /) -> ArrayVar[List[int]]: ...
 
     @overload
     @classmethod
@@ -766,7 +876,7 @@ class ArrayVar(ImmutableVar):
         end: int | NumberVar,
         step: int | NumberVar = 1,
         /,
-    ) -> RangeOperation: ...
+    ) -> ArrayVar[List[int]]: ...
 
     @classmethod
     def range(
@@ -774,7 +884,7 @@ class ArrayVar(ImmutableVar):
         first_endpoint: int | NumberVar,
         second_endpoint: int | NumberVar | None = None,
         step: int | NumberVar | None = None,
-    ) -> RangeOperation:
+    ) -> ArrayVar[List[int]]:
         """Create a range of numbers.
 
         Args:
@@ -794,7 +904,7 @@ class ArrayVar(ImmutableVar):
 
         return RangeOperation(start, end, step or 1)
 
-    def contains(self, other: Any) -> ArrayContainsOperation:
+    def contains(self, other: Any) -> BooleanVar:
         """Check if the array contains an element.
 
         Args:
@@ -806,12 +916,21 @@ class ArrayVar(ImmutableVar):
         return ArrayContainsOperation(self, other)
 
 
+LIST_ELEMENT = TypeVar("LIST_ELEMENT")
+
+ARRAY_VAR_OF_LIST_ELEMENT = Union[
+    ArrayVar[List[LIST_ELEMENT]],
+    ArrayVar[Set[LIST_ELEMENT]],
+    ArrayVar[Tuple[LIST_ELEMENT, ...]],
+]
+
+
 @dataclasses.dataclass(
     eq=False,
     frozen=True,
     **{"slots": True} if sys.version_info >= (3, 10) else {},
 )
-class LiteralArrayVar(LiteralVar, ArrayVar):
+class LiteralArrayVar(LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
     """Base class for immutable literal array vars."""
 
     _var_value: Union[
@@ -819,9 +938,9 @@ class LiteralArrayVar(LiteralVar, ArrayVar):
     ] = dataclasses.field(default_factory=list)
 
     def __init__(
-        self,
-        _var_value: list[Var | Any] | tuple[Var | Any, ...] | set[Var | Any],
-        _var_type: type[list] | type[tuple] | type[set] | None = None,
+        self: LiteralArrayVar[ARRAY_VAR_TYPE],
+        _var_value: ARRAY_VAR_TYPE,
+        _var_type: type[ARRAY_VAR_TYPE] | None = None,
         _var_data: VarData | None = None,
     ):
         """Initialize the array var.
@@ -834,11 +953,7 @@ class LiteralArrayVar(LiteralVar, ArrayVar):
         super(LiteralArrayVar, self).__init__(
             _var_name="",
             _var_data=ImmutableVarData.merge(_var_data),
-            _var_type=(
-                List[unionize(*map(type, _var_value))]
-                if _var_type is None
-                else _var_type
-            ),
+            _var_type=(figure_out_type(_var_value) if _var_type is None else _var_type),
         )
         object.__setattr__(self, "_var_value", _var_value)
         object.__delattr__(self, "_var_name")
@@ -1261,23 +1376,6 @@ class ArrayLengthOperation(ArrayToNumberOperation):
         return f"{str(self.a)}.length"
 
 
-def unionize(*args: Type) -> Type:
-    """Unionize the types.
-
-    Args:
-        args: The types to unionize.
-
-    Returns:
-        The unionized types.
-    """
-    if not args:
-        return Any
-    first, *rest = args
-    if not rest:
-        return first
-    return Union[first, unionize(*rest)]
-
-
 def is_tuple_type(t: GenericType) -> bool:
     """Check if a type is a tuple type.
 

+ 26 - 1
tests/test_var.py

@@ -1042,7 +1042,7 @@ def test_object_operations():
 
 def test_type_chains():
     object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3})
-    assert object_var._var_type is Dict[str, int]
+    assert (object_var._key_type(), object_var._value_type()) == (str, int)
     assert (object_var.keys()._var_type, object_var.values()._var_type) == (
         List[str],
         List[int],
@@ -1061,6 +1061,31 @@ def test_type_chains():
     )
 
 
+def test_nested_dict():
+    arr = LiteralArrayVar([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]])
+
+    assert (
+        str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)'
+    )
+
+
+def nested_base():
+    class Boo(Base):
+        foo: str
+        bar: int
+
+    class Foo(Base):
+        bar: Boo
+        baz: int
+
+    parent_obj = LiteralVar.create(Foo(bar=Boo(foo="bar", bar=5), baz=5))
+
+    assert (
+        str(parent_obj.bar.foo)
+        == '({ ["bar"] : ({ ["foo"] : "bar", ["bar"] : 5 }), ["baz"] : 5 })["bar"]["foo"]'
+    )
+
+
 def test_retrival():
     var_without_data = ImmutableVar.create("test")
     assert var_without_data is not None