|
@@ -120,6 +120,10 @@ class Unset:
|
|
|
|
|
|
|
|
|
@lru_cache
|
|
|
+def _get_origin_cached(tp: Any):
|
|
|
+ return get_origin_og(tp)
|
|
|
+
|
|
|
+
|
|
|
def get_origin(tp: Any):
|
|
|
"""Get the origin of a class.
|
|
|
|
|
@@ -129,7 +133,11 @@ def get_origin(tp: Any):
|
|
|
Returns:
|
|
|
The origin of the class.
|
|
|
"""
|
|
|
- return get_origin_og(tp)
|
|
|
+ return (
|
|
|
+ origin
|
|
|
+ if (origin := getattr(tp, "__origin__", None)) is not None
|
|
|
+ else _get_origin_cached(tp)
|
|
|
+ )
|
|
|
|
|
|
|
|
|
@lru_cache
|
|
@@ -190,7 +198,6 @@ def is_none(cls: GenericType) -> bool:
|
|
|
return cls is type(None) or cls is None
|
|
|
|
|
|
|
|
|
-@lru_cache
|
|
|
def is_union(cls: GenericType) -> bool:
|
|
|
"""Check if a class is a Union.
|
|
|
|
|
@@ -200,10 +207,12 @@ def is_union(cls: GenericType) -> bool:
|
|
|
Returns:
|
|
|
Whether the class is a Union.
|
|
|
"""
|
|
|
- return get_origin(cls) in UnionTypes
|
|
|
+ origin = getattr(cls, "__origin__", None)
|
|
|
+ if origin is Union:
|
|
|
+ return True
|
|
|
+ return origin is None and isinstance(cls, types.UnionType)
|
|
|
|
|
|
|
|
|
-@lru_cache
|
|
|
def is_literal(cls: GenericType) -> bool:
|
|
|
"""Check if a class is a Literal.
|
|
|
|
|
@@ -213,7 +222,7 @@ def is_literal(cls: GenericType) -> bool:
|
|
|
Returns:
|
|
|
Whether the class is a literal.
|
|
|
"""
|
|
|
- return get_origin(cls) is Literal
|
|
|
+ return getattr(cls, "__origin__", None) is Literal
|
|
|
|
|
|
|
|
|
def has_args(cls: type) -> bool:
|