Browse Source

more types

Khaleel Al-Adhami 6 months ago
parent
commit
079cc56f59
2 changed files with 67 additions and 15 deletions
  1. 55 13
      reflex/utils/types.py
  2. 12 2
      tests/units/utils/test_utils.py

+ 55 - 13
reflex/utils/types.py

@@ -7,6 +7,7 @@ import dataclasses
 import inspect
 import sys
 import types
+from collections import abc
 from functools import cached_property, lru_cache, wraps
 from typing import (
     TYPE_CHECKING,
@@ -21,6 +22,7 @@ from typing import (
     Sequence,
     Tuple,
     Type,
+    TypeVar,
     Union,
     _GenericAlias,  # type: ignore
     get_args,
@@ -29,6 +31,7 @@ from typing import (
 from typing import get_origin as get_origin_og
 
 import sqlalchemy
+import typing_extensions
 
 import reflex
 from reflex.components.core.breakpoints import Breakpoints
@@ -810,24 +813,63 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
     provided_args = get_args(possible_subclass)
     accepted_args = get_args(possible_superclass)
 
-    if accepted_type_origin is Union:
-        if provided_type_origin is not Union:
-            return any(
-                typehint_issubclass(possible_subclass, accepted_arg)
-                for accepted_arg in accepted_args
-            )
+    if provided_type_origin is Union:
         return all(
-            any(
-                typehint_issubclass(provided_arg, accepted_arg)
-                for accepted_arg in accepted_args
-            )
+            typehint_issubclass(provided_arg, possible_superclass)
             for provided_arg in provided_args
         )
 
+    if accepted_type_origin is Union:
+        return any(
+            typehint_issubclass(possible_subclass, accepted_arg)
+            for accepted_arg in accepted_args
+        )
+
+    # Check specifically for Sequence and Iterable
+    if (accepted_type_origin or possible_superclass) in (
+        Sequence,
+        abc.Sequence,
+        Iterable,
+        abc.Iterable,
+    ):
+        iterable_type = accepted_args[0] if accepted_args else Any
+
+        if provided_type_origin is None:
+            if not issubclass(
+                possible_subclass, (accepted_type_origin or possible_superclass)
+            ):
+                return False
+
+            if issubclass(possible_subclass, str) and not isinstance(
+                iterable_type, TypeVar
+            ):
+                return typehint_issubclass(str, iterable_type)
+
+        if not issubclass(
+            provided_type_origin, (accepted_type_origin or possible_superclass)
+        ):
+            return False
+
+        if not isinstance(iterable_type, (TypeVar, typing_extensions.TypeVar)):
+            if provided_type_origin in (list, tuple, set):
+                # Ensure all specific types are compatible with accepted types
+                return all(
+                    typehint_issubclass(provided_arg, iterable_type)
+                    for provided_arg in provided_args
+                    if provided_arg is not ...  # Ellipsis in Tuples
+                )
+            if possible_subclass is dict:
+                # Ensure all specific types are compatible with accepted types
+                return all(
+                    typehint_issubclass(provided_arg, iterable_type)
+                    for provided_arg in provided_args[:1]
+                )
+        return True
+
     # Check if the origin of both types is the same (e.g., list for List[int])
-    # This probably should be issubclass instead of ==
-    if (provided_type_origin or possible_subclass) != (
-        accepted_type_origin or possible_superclass
+    if not issubclass(
+        provided_type_origin or possible_subclass,
+        accepted_type_origin or possible_superclass,
     ):
         return False
 

+ 12 - 2
tests/units/utils/test_utils.py

@@ -2,7 +2,7 @@ import os
 import typing
 from functools import cached_property
 from pathlib import Path
-from typing import Any, ClassVar, Dict, List, Literal, Type, Union
+from typing import Any, ClassVar, Dict, List, Literal, Sequence, Tuple, Type, Union
 
 import pytest
 import typer
@@ -109,10 +109,20 @@ def test_is_generic_alias(cls: type, expected: bool):
         (Dict[str, str], dict[str, str], True),
         (Dict[str, str], dict[str, Any], True),
         (Dict[str, Any], dict[str, Any], True),
+        (List[int], Sequence[int], True),
+        (List[str], Sequence[int], False),
+        (Tuple[int], Sequence[int], True),
+        (Tuple[int, str], Sequence[int], False),
+        (Tuple[int, ...], Sequence[int], True),
+        (str, Sequence[int], False),
+        (str, Sequence[str], True),
     ],
 )
 def test_typehint_issubclass(subclass, superclass, expected):
-    assert types.typehint_issubclass(subclass, superclass) == expected
+    if expected:
+        assert types.typehint_issubclass(subclass, superclass)
+    else:
+        assert not types.typehint_issubclass(subclass, superclass)
 
 
 def test_validate_invalid_bun_path(mocker):