Browse Source

use knowledge about generic types to improve their getters and checkers (#5245)

Khaleel Al-Adhami 1 week ago
parent
commit
244ed1f744
1 changed files with 14 additions and 5 deletions
  1. 14 5
      reflex/utils/types.py

+ 14 - 5
reflex/utils/types.py

@@ -120,6 +120,10 @@ class Unset:
 
 
 @lru_cache
+def _get_origin_cached(tp: Any):
+    return get_origin_og(tp)
+
+
 def get_origin(tp: Any):
     """Get the origin of a class.
 
@@ -129,7 +133,11 @@ def get_origin(tp: Any):
     Returns:
         The origin of the class.
     """
-    return get_origin_og(tp)
+    return (
+        origin
+        if (origin := getattr(tp, "__origin__", None)) is not None
+        else _get_origin_cached(tp)
+    )
 
 
 @lru_cache
@@ -190,7 +198,6 @@ def is_none(cls: GenericType) -> bool:
     return cls is type(None) or cls is None
 
 
-@lru_cache
 def is_union(cls: GenericType) -> bool:
     """Check if a class is a Union.
 
@@ -200,10 +207,12 @@ def is_union(cls: GenericType) -> bool:
     Returns:
         Whether the class is a Union.
     """
-    return get_origin(cls) in UnionTypes
+    origin = getattr(cls, "__origin__", None)
+    if origin is Union:
+        return True
+    return origin is None and isinstance(cls, types.UnionType)
 
 
-@lru_cache
 def is_literal(cls: GenericType) -> bool:
     """Check if a class is a Literal.
 
@@ -213,7 +222,7 @@ def is_literal(cls: GenericType) -> bool:
     Returns:
         Whether the class is a literal.
     """
-    return get_origin(cls) is Literal
+    return getattr(cls, "__origin__", None) is Literal
 
 
 def has_args(cls: type) -> bool: