Переглянути джерело

make object var handle all mapping instead of just dict (#4602)

* make object var handle all mapping instead of just dict

* unbreak ci

* get it right pyright

* create generic variable for field

* add support for typeddict (to some degree)

* import from extensions
Khaleel Al-Adhami 4 місяців тому
батько
коміт
bea266b8ed

+ 16 - 0
reflex/utils/types.py

@@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar)
 StateIterBases = get_base_class(StateIterVar)
 
 
+def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
+    """Check if a class is a subclass of another class. Returns False if internal error occurs.
+
+    Args:
+        cls: The class to check.
+        cls_check: The class to check against.
+
+    Returns:
+        Whether the class is a subclass of the other class.
+    """
+    try:
+        return issubclass(cls, cls_check)
+    except TypeError:
+        return False
+
+
 def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
     """Check if a type hint is a subclass of another type hint.
 

+ 29 - 20
reflex/vars/base.py

@@ -26,6 +26,7 @@ from typing import (
     Iterable,
     List,
     Literal,
+    Mapping,
     NoReturn,
     Optional,
     Set,
@@ -64,6 +65,7 @@ from reflex.utils.types import (
     _isinstance,
     get_origin,
     has_args,
+    safe_issubclass,
     unionize,
 )
 
@@ -127,7 +129,7 @@ class VarData:
         state: str = "",
         field_name: str = "",
         imports: ImportDict | ParsedImportDict | None = None,
-        hooks: dict[str, VarData | None] | None = None,
+        hooks: Mapping[str, VarData | None] | None = None,
         deps: list[Var] | None = None,
         position: Hooks.HookPosition | None = None,
     ):
@@ -643,8 +645,8 @@ class Var(Generic[VAR_TYPE]):
     @overload
     def to(
         self,
-        output: type[dict],
-    ) -> ObjectVar[dict]: ...
+        output: type[Mapping],
+    ) -> ObjectVar[Mapping]: ...
 
     @overload
     def to(
@@ -686,7 +688,9 @@ class Var(Generic[VAR_TYPE]):
 
         # If the first argument is a python type, we map it to the corresponding Var type.
         for var_subclass in _var_subclasses[::-1]:
-            if fixed_output_type in var_subclass.python_types:
+            if fixed_output_type in var_subclass.python_types or safe_issubclass(
+                fixed_output_type, var_subclass.python_types
+            ):
                 return self.to(var_subclass.var_subclass, output)
 
         if fixed_output_type is None:
@@ -820,7 +824,7 @@ class Var(Generic[VAR_TYPE]):
             return False
         if issubclass(type_, list):
             return []
-        if issubclass(type_, dict):
+        if issubclass(type_, Mapping):
             return {}
         if issubclass(type_, tuple):
             return ()
@@ -1026,7 +1030,7 @@ class Var(Generic[VAR_TYPE]):
                     f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
                 }
             ),
-        ).to(ObjectVar, Dict[str, str])
+        ).to(ObjectVar, Mapping[str, str])
         return refs[LiteralVar.create(str(self))]
 
     @deprecated("Use `.js_type()` instead.")
@@ -1373,7 +1377,7 @@ class LiteralVar(Var):
 
         serialized_value = serializers.serialize(value)
         if serialized_value is not None:
-            if isinstance(serialized_value, dict):
+            if isinstance(serialized_value, Mapping):
                 return LiteralObjectVar.create(
                     serialized_value,
                     _var_type=type(value),
@@ -1498,7 +1502,7 @@ def var_operation(
 ) -> Callable[P, ArrayVar[LIST_T]]: ...
 
 
-OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
+OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
 
 
 @overload
@@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType:
         return Set[unionize(*(figure_out_type(v) for v in value))]
     if isinstance(value, tuple):
         return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
-    if isinstance(value, dict):
-        return Dict[
+    if isinstance(value, Mapping):
+        return Mapping[
             unionize(*(figure_out_type(k) for k in value)),
             unionize(*(figure_out_type(v) for v in value.values())),
         ]
@@ -2002,10 +2006,10 @@ class ComputedVar(Var[RETURN_TYPE]):
 
     @overload
     def __get__(
-        self: ComputedVar[dict[DICT_KEY, DICT_VAL]],
+        self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
         instance: None,
         owner: Type,
-    ) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ...
+    ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
 
     @overload
     def __get__(
@@ -2915,11 +2919,14 @@ V = TypeVar("V")
 
 BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
 
+FIELD_TYPE = TypeVar("FIELD_TYPE")
+MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
+
 
-class Field(Generic[T]):
+class Field(Generic[FIELD_TYPE]):
     """Shadow class for Var to allow for type hinting in the IDE."""
 
-    def __set__(self, instance, value: T):
+    def __set__(self, instance, value: FIELD_TYPE):
         """Set the Var.
 
         Args:
@@ -2931,7 +2938,9 @@ class Field(Generic[T]):
     def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
 
     @overload
-    def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
+    def __get__(
+        self: Field[int] | Field[float] | Field[int | float], instance: None, owner
+    ) -> NumberVar: ...
 
     @overload
     def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
@@ -2948,8 +2957,8 @@ class Field(Generic[T]):
 
     @overload
     def __get__(
-        self: Field[Dict[str, V]], instance: None, owner
-    ) -> ObjectVar[Dict[str, V]]: ...
+        self: Field[MAPPING_TYPE], instance: None, owner
+    ) -> ObjectVar[MAPPING_TYPE]: ...
 
     @overload
     def __get__(
@@ -2957,10 +2966,10 @@ class Field(Generic[T]):
     ) -> ObjectVar[BASE_TYPE]: ...
 
     @overload
-    def __get__(self, instance: None, owner) -> Var[T]: ...
+    def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ...
 
     @overload
-    def __get__(self, instance, owner) -> T: ...
+    def __get__(self, instance, owner) -> FIELD_TYPE: ...
 
     def __get__(self, instance, owner):  # type: ignore
         """Get the Var.
@@ -2971,7 +2980,7 @@ class Field(Generic[T]):
         """
 
 
-def field(value: T) -> Field[T]:
+def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
     """Create a Field with a value.
 
     Args:

+ 43 - 32
reflex/vars/object.py

@@ -8,8 +8,8 @@ import typing
 from inspect import isclass
 from typing import (
     Any,
-    Dict,
     List,
+    Mapping,
     NoReturn,
     Tuple,
     Type,
@@ -19,6 +19,8 @@ from typing import (
     overload,
 )
 
+from typing_extensions import is_typeddict
+
 from reflex.utils import types
 from reflex.utils.exceptions import VarAttributeError
 from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
@@ -36,7 +38,7 @@ from .base import (
 from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
 from .sequence import ArrayVar, StringVar
 
-OBJECT_TYPE = TypeVar("OBJECT_TYPE")
+OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
 
 KEY_TYPE = TypeVar("KEY_TYPE")
 VALUE_TYPE = TypeVar("VALUE_TYPE")
@@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
 OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
 
 
-class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
+class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
     """Base class for immutable object vars."""
 
     def _key_type(self) -> Type:
@@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 
     @overload
     def _value_type(
-        self: ObjectVar[Dict[Any, VALUE_TYPE]],
+        self: ObjectVar[Mapping[Any, VALUE_TYPE]],
     ) -> Type[VALUE_TYPE]: ...
 
     @overload
@@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
         fixed_type = get_origin(self._var_type) or self._var_type
         if not isclass(fixed_type):
             return Any
-        args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
+        args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else ()
         return args[1] if args else Any
 
     def keys(self) -> ArrayVar[List[str]]:
@@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 
     @overload
     def values(
-        self: ObjectVar[Dict[Any, VALUE_TYPE]],
+        self: ObjectVar[Mapping[Any, VALUE_TYPE]],
     ) -> ArrayVar[List[VALUE_TYPE]]: ...
 
     @overload
@@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 
     @overload
     def entries(
-        self: ObjectVar[Dict[Any, VALUE_TYPE]],
+        self: ObjectVar[Mapping[Any, VALUE_TYPE]],
     ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
 
     @overload
@@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
     # NoReturn is used here to catch when key value is Any
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, NoReturn]],
+        self: ObjectVar[Mapping[Any, NoReturn]],
         key: Var | Any,
     ) -> Var: ...
 
+    @overload
+    def __getitem__(
+        self: (ObjectVar[Mapping[Any, bool]]),
+        key: Var | Any,
+    ) -> BooleanVar: ...
+
     @overload
     def __getitem__(
         self: (
-            ObjectVar[Dict[Any, int]]
-            | ObjectVar[Dict[Any, float]]
-            | ObjectVar[Dict[Any, int | float]]
+            ObjectVar[Mapping[Any, int]]
+            | ObjectVar[Mapping[Any, float]]
+            | ObjectVar[Mapping[Any, int | float]]
         ),
         key: Var | Any,
     ) -> NumberVar: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, str]],
+        self: ObjectVar[Mapping[Any, str]],
         key: Var | Any,
     ) -> StringVar: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
         key: Var | Any,
     ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
         key: Var | Any,
     ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
+        self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
         key: Var | Any,
     ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
 
     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
+        self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
         key: Var | Any,
-    ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
+    ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
 
     def __getitem__(self, key: Var | Any) -> Var:
         """Get an item from the object.
@@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
     # NoReturn is used here to catch when key value is Any
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[Any, NoReturn]],
+        self: ObjectVar[Mapping[Any, NoReturn]],
         name: str,
     ) -> Var: ...
 
     @overload
     def __getattr__(
         self: (
-            ObjectVar[Dict[Any, int]]
-            | ObjectVar[Dict[Any, float]]
-            | ObjectVar[Dict[Any, int | float]]
+            ObjectVar[Mapping[Any, int]]
+            | ObjectVar[Mapping[Any, float]]
+            | ObjectVar[Mapping[Any, int | float]]
         ),
         name: str,
     ) -> NumberVar: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[Any, str]],
+        self: ObjectVar[Mapping[Any, str]],
         name: str,
     ) -> StringVar: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
         name: str,
     ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
         name: str,
     ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
+        self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
         name: str,
     ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
 
     @overload
     def __getattr__(
-        self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
+        self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
         name: str,
-    ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
+    ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
 
     @overload
     def __getattr__(
@@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
             var_type = get_args(var_type)[0]
 
         fixed_type = var_type if isclass(var_type) else get_origin(var_type)
-        if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
-            fixed_type in types.UnionTypes
+
+        if (
+            (isclass(fixed_type) and not issubclass(fixed_type, Mapping))
+            or (fixed_type in types.UnionTypes)
+            or is_typeddict(fixed_type)
         ):
             attribute_type = get_attribute_access_type(var_type, name)
             if attribute_type is None:
@@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
 class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
     """Base class for immutable literal object vars."""
 
-    _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
+    _var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
         default_factory=dict
     )
 
@@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
     @classmethod
     def create(
         cls,
-        _var_value: dict,
+        _var_value: Mapping,
         _var_type: Type[OBJECT_TYPE] | None = None,
         _var_data: VarData | None = None,
     ) -> LiteralObjectVar[OBJECT_TYPE]:
@@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
     """
     return var_operation_return(
         js_expression=f"({{...{lhs}, ...{rhs}}})",
-        var_type=Dict[
+        var_type=Mapping[
             Union[lhs._key_type(), rhs._key_type()],
             Union[lhs._value_type(), rhs._value_type()],
         ],

+ 1 - 1
reflex/vars/sequence.py

@@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
             raise_unsupported_operand_types("[]", (type(self), type(i)))
         return array_item_operation(self, i)
 
-    def length(self) -> NumberVar:
+    def length(self) -> NumberVar[int]:
         """Get the length of the array.
 
         Returns:

+ 2 - 2
tests/units/components/core/test_match.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import List, Mapping, Tuple
 
 import pytest
 
@@ -67,7 +67,7 @@ def test_match_components():
     assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
 
     assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
-    assert match_cases[4][0]._var_type == Dict[str, str]
+    assert match_cases[4][0]._var_type == Mapping[str, str]
     fifth_return_value_render = match_cases[4][1].render()
     assert fifth_return_value_render["name"] == "RadixThemesText"
     assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'

+ 2 - 2
tests/units/test_style.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Dict
+from typing import Any, Mapping
 
 import pytest
 
@@ -379,7 +379,7 @@ class StyleState(rx.State):
             {
                 "css": Var(
                     _js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})'
-                ).to(Dict[str, str])
+                ).to(Mapping[str, str])
             },
         ),
         (

+ 2 - 2
tests/units/test_var.py

@@ -2,7 +2,7 @@ import json
 import math
 import sys
 import typing
-from typing import Dict, List, Optional, Set, Tuple, Union, cast
+from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
 
 import pytest
 from pandas import DataFrame
@@ -270,7 +270,7 @@ def test_get_setter(prop: Var, expected):
         ([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])),
         (
             {"a": 1, "b": 2},
-            Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]),
+            Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]),
         ),
     ],
 )

+ 5 - 5
tests/units/vars/test_base.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Union
+from typing import List, Mapping, Union
 
 import pytest
 
@@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict):
         ("a", str),
         ([1, 2, 3], List[int]),
         ([1, 2.0, "a"], List[Union[int, float, str]]),
-        ({"a": 1, "b": 2}, Dict[str, int]),
-        ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
+        ({"a": 1, "b": 2}, Mapping[str, int]),
+        ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]),
         (CustomDict(), CustomDict),
         (ChildCustomDict(), ChildCustomDict),
-        (GenericDict({1: 1}), Dict[int, int]),
-        (ChildGenericDict({1: 1}), Dict[int, int]),
+        (GenericDict({1: 1}), Mapping[int, int]),
+        (ChildGenericDict({1: 1}), Mapping[int, int]),
     ],
 )
 def test_figure_out_type(value, expected):