"""Contains custom types and methods to check types.""" from __future__ import annotations import contextlib from typing import Any, Callable, Tuple, Type, Union, _GenericAlias # type: ignore from pynecone.base import Base # Union of generic types. GenericType = Union[Type, _GenericAlias] # Valid state var types. PrimitiveType = Union[int, float, bool, str, list, dict, tuple] StateVar = Union[PrimitiveType, Base, None] def get_args(alias: _GenericAlias) -> Tuple[Type, ...]: """Get the arguments of a type alias. Args: alias: The type alias. Returns: The arguments of the type alias. """ return alias.__args__ def is_generic_alias(cls: GenericType) -> bool: """Check whether the class is a generic alias. Args: cls: The class to check. Returns: Whether the class is a generic alias. """ # For older versions of Python. if isinstance(cls, _GenericAlias): return True with contextlib.suppress(ImportError): from typing import _SpecialGenericAlias # type: ignore if isinstance(cls, _SpecialGenericAlias): return True # For newer versions of Python. try: from types import GenericAlias # type: ignore return isinstance(cls, GenericAlias) except ImportError: return False def is_union(cls: GenericType) -> bool: """Check if a class is a Union. Args: cls: The class to check. Returns: Whether the class is a Union. """ with contextlib.suppress(ImportError): from typing import _UnionGenericAlias # type: ignore return isinstance(cls, _UnionGenericAlias) return cls.__origin__ == Union if is_generic_alias(cls) else False def get_base_class(cls: GenericType) -> Type: """Get the base class of a class. Args: cls: The class. Returns: The base class of the class. """ if is_union(cls): return tuple(get_base_class(arg) for arg in get_args(cls)) return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls def _issubclass(cls: GenericType, cls_check: GenericType) -> bool: """Check if a class is a subclass of another class. Args: cls: The class to check. cls_check: The class to check against. Returns: Whether the class is a subclass of the other class. """ # Special check for Any. if cls_check == Any: return True if cls in [Any, Callable]: return False # Get the base classes. cls_base = get_base_class(cls) cls_check_base = get_base_class(cls_check) # The class we're checking should not be a union. if isinstance(cls_base, tuple): return False # Check if the types match. return cls_check_base == Any or issubclass(cls_base, cls_check_base) def _isinstance(obj: Any, cls: GenericType) -> bool: """Check if an object is an instance of a class. Args: obj: The object to check. cls: The class to check against. Returns: Whether the object is an instance of the class. """ return isinstance(obj, get_base_class(cls)) def is_dataframe(value: Type) -> bool: """Check if the given value is a dataframe. Args: value: The value to check. Returns: Whether the value is a dataframe. """ return value.__name__ == "DataFrame" def is_figure(value: Type) -> bool: """Check if the given value is a figure. Args: value: The value to check. Returns: Whether the value is a figure. """ return value.__name__ == "Figure" def is_valid_var_type(var: Type) -> bool: """Check if the given value is a valid prop type. Args: var: The value to check. Returns: Whether the value is a valid prop type. """ return _issubclass(var, StateVar) or is_dataframe(var) or is_figure(var) def is_backend_variable(name: str) -> bool: """Check if this variable name correspond to a backend variable. Args: name: The name of the variable to check Returns: bool: The result of the check """ return name.startswith("_") and not name.startswith("__") # Store this here for performance. StateBases = get_base_class(StateVar)