Selaa lähdekoodia

add typed dict type checking (#4340)

* add typed dict type checking

* technically it has to be a mapping, not specifically a dict
Khaleel Al-Adhami 6 kuukautta sitten
vanhempi
säilyke
bcea79cd45
2 muutettua tiedostoa jossa 66 lisäystä ja 3 poistoa
  1. 18 2
      reflex/components/component.py
  2. 48 1
      reflex/utils/types.py

+ 18 - 2
reflex/components/component.py

@@ -186,6 +186,23 @@ ComponentStyle = Dict[
 ComponentChild = Union[types.PrimitiveType, Var, BaseComponent]
 
 
+def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
+    """Check if an object satisfies a type hint.
+
+    Args:
+        obj: The object to check.
+        type_hint: The type hint to check against.
+
+    Returns:
+        Whether the object satisfies the type hint.
+    """
+    if isinstance(obj, LiteralVar):
+        return types._isinstance(obj._var_value, type_hint)
+    if isinstance(obj, Var):
+        return types._issubclass(obj._var_type, type_hint)
+    return types._isinstance(obj, type_hint)
+
+
 class Component(BaseComponent, ABC):
     """A component with style, event trigger and other props."""
 
@@ -460,8 +477,7 @@ class Component(BaseComponent, ABC):
                     )
                 ) or (
                     # Else just check if the passed var type is valid.
-                    not passed_types
-                    and not types._issubclass(passed_type, expected_type, value)
+                    not passed_types and not satisfies_type_hint(value, expected_type)
                 ):
                     value_name = value._js_expr if isinstance(value, Var) else value
 

+ 48 - 1
reflex/utils/types.py

@@ -14,9 +14,11 @@ from typing import (
     Callable,
     ClassVar,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Literal,
+    Mapping,
     Optional,
     Sequence,
     Tuple,
@@ -29,6 +31,7 @@ from typing import (
 from typing import get_origin as get_origin_og
 
 import sqlalchemy
+from typing_extensions import is_typeddict
 
 import reflex
 from reflex.components.core.breakpoints import Breakpoints
@@ -494,6 +497,14 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None)
     if isinstance(instance, Breakpoints):
         return _breakpoints_satisfies_typing(cls_check, instance)
 
+    if isinstance(cls_check_base, tuple):
+        cls_check_base = tuple(
+            cls_check_one if not is_typeddict(cls_check_one) else dict
+            for cls_check_one in cls_check_base
+        )
+    if is_typeddict(cls_check_base):
+        cls_check_base = dict
+
     # Check if the types match.
     try:
         return cls_check_base == Any or issubclass(cls_base, cls_check_base)
@@ -503,6 +514,36 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None)
         raise TypeError(f"Invalid type for issubclass: {cls_base}") from te
 
 
+def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
+    """Check if an object satisfies a typed dict.
+
+    Args:
+        obj: The object to check.
+        cls: The typed dict to check against.
+
+    Returns:
+        Whether the object satisfies the typed dict.
+    """
+    if not isinstance(obj, Mapping):
+        return False
+
+    key_names_to_values = get_type_hints(cls)
+    required_keys: FrozenSet[str] = getattr(cls, "__required_keys__", frozenset())
+
+    if not all(
+        isinstance(key, str)
+        and key in key_names_to_values
+        and _isinstance(value, key_names_to_values[key])
+        for key, value in obj.items()
+    ):
+        return False
+
+    # TODO in 3.14: Implement https://peps.python.org/pep-0728/ if it's approved
+
+    # required keys are all present
+    return required_keys.issubset(required_keys)
+
+
 def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
     """Check if an object is an instance of a class.
 
@@ -529,6 +570,12 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
     origin = get_origin(cls)
 
     if origin is None:
+        # cls is a typed dict
+        if is_typeddict(cls):
+            if nested:
+                return does_obj_satisfy_typed_dict(obj, cls)
+            return isinstance(obj, dict)
+
         # cls is a simple class
         return isinstance(obj, cls)
 
@@ -553,7 +600,7 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
                 and len(obj) == len(args)
                 and all(_isinstance(item, arg) for item, arg in zip(obj, args))
             )
-        if origin is dict:
+        if origin in (dict, Breakpoints):
             return isinstance(obj, dict) and all(
                 _isinstance(key, args[0]) and _isinstance(value, args[1])
                 for key, value in obj.items()