"""Contains custom types and methods to check types.""" from __future__ import annotations import contextlib import typing from datetime import date, datetime, time, timedelta from typing import Any, Callable, Type, Union, _GenericAlias # type: ignore from reflex.base import Base # Union of generic types. GenericType = Union[Type, _GenericAlias] # Valid state var types. PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple] StateVar = Union[PrimitiveType, Base, None] StateIterVar = Union[list, set, tuple] def get_args(alias: _GenericAlias) -> tuple[Type, ...]: """Get the arguments of a type alias. Args: alias: The type alias. Returns: The arguments of the type alias. """ return alias.__args__ def is_generic_alias(cls: GenericType) -> bool: """Check whether the class is a generic alias. Args: cls: The class to check. Returns: Whether the class is a generic alias. """ # For older versions of Python. if isinstance(cls, _GenericAlias): return True with contextlib.suppress(ImportError): from typing import _SpecialGenericAlias # type: ignore if isinstance(cls, _SpecialGenericAlias): return True # For newer versions of Python. try: from types import GenericAlias # type: ignore return isinstance(cls, GenericAlias) except ImportError: return False def is_union(cls: GenericType) -> bool: """Check if a class is a Union. Args: cls: The class to check. Returns: Whether the class is a Union. """ with contextlib.suppress(ImportError): from typing import _UnionGenericAlias # type: ignore return isinstance(cls, _UnionGenericAlias) return cls.__origin__ == Union if is_generic_alias(cls) else False def get_base_class(cls: GenericType) -> Type: """Get the base class of a class. Args: cls: The class. Returns: The base class of the class. """ if is_union(cls): return tuple(get_base_class(arg) for arg in get_args(cls)) return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls def _issubclass(cls: GenericType, cls_check: GenericType) -> bool: """Check if a class is a subclass of another class. Args: cls: The class to check. cls_check: The class to check against. Returns: Whether the class is a subclass of the other class. """ # Special check for Any. if cls_check == Any: return True if cls in [Any, Callable]: return False # Get the base classes. cls_base = get_base_class(cls) cls_check_base = get_base_class(cls_check) # The class we're checking should not be a union. if isinstance(cls_base, tuple): return False # Check if the types match. return cls_check_base == Any or issubclass(cls_base, cls_check_base) def _isinstance(obj: Any, cls: GenericType) -> bool: """Check if an object is an instance of a class. Args: obj: The object to check. cls: The class to check against. Returns: Whether the object is an instance of the class. """ return isinstance(obj, get_base_class(cls)) def is_dataframe(value: Type) -> bool: """Check if the given value is a dataframe. Args: value: The value to check. Returns: Whether the value is a dataframe. """ if is_generic_alias(value) or value == typing.Any: return False return value.__name__ == "DataFrame" def is_image(value: Type) -> bool: """Check if the given value is a pillow image. By checking if the value subclasses PIL. Args: value: The value to check. Returns: Whether the value is a pillow image. """ if is_generic_alias(value) or value == typing.Any: return False return "PIL" in value.__module__ def is_figure(value: Type) -> bool: """Check if the given value is a figure. Args: value: The value to check. Returns: Whether the value is a figure. """ return value.__name__ == "Figure" def is_datetime(value: Type) -> bool: """Check if the given value is a datetime object. Args: value: The value to check. Returns: Whether the value is a date, datetime, time, or timedelta. """ return issubclass(value, (date, datetime, time, timedelta)) def is_valid_var_type(var: Type) -> bool: """Check if the given value is a valid prop type. Args: var: The value to check. Returns: Whether the value is a valid prop type. """ return ( _issubclass(var, StateVar) or is_dataframe(var) or is_figure(var) or is_image(var) or is_datetime(var) ) def is_backend_variable(name: str) -> bool: """Check if this variable name correspond to a backend variable. Args: name: The name of the variable to check Returns: bool: The result of the check """ return name.startswith("_") and not name.startswith("__") # Store this here for performance. StateBases = get_base_class(StateVar) StateIterBases = get_base_class(StateIterVar)