types.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195
  1. """Contains custom types and methods to check types."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import inspect
  5. import types
  6. from collections.abc import Callable, Iterable, Mapping, Sequence
  7. from functools import cached_property, lru_cache, wraps
  8. from types import GenericAlias
  9. from typing import ( # noqa: UP035
  10. TYPE_CHECKING,
  11. Any,
  12. Awaitable,
  13. ClassVar,
  14. Dict,
  15. ForwardRef,
  16. List,
  17. Literal,
  18. MutableMapping,
  19. NoReturn,
  20. Protocol,
  21. Tuple,
  22. TypeVar,
  23. Union,
  24. _GenericAlias, # pyright: ignore [reportAttributeAccessIssue]
  25. _SpecialGenericAlias, # pyright: ignore [reportAttributeAccessIssue]
  26. get_args,
  27. is_typeddict,
  28. )
  29. from typing import get_origin as get_origin_og
  30. from typing import get_type_hints as get_type_hints_og
  31. from pydantic.v1.fields import ModelField
  32. from typing_extensions import Self as Self
  33. from typing_extensions import override as override
  34. import reflex
  35. from reflex import constants
  36. from reflex.base import Base
  37. from reflex.components.core.breakpoints import Breakpoints
  38. from reflex.utils import console
  39. # Potential GenericAlias types for isinstance checks.
  40. GenericAliasTypes = (_GenericAlias, GenericAlias, _SpecialGenericAlias)
  41. # Potential Union types for isinstance checks.
  42. UnionTypes = (Union, types.UnionType)
  43. # Union of generic types.
  44. GenericType = type | _GenericAlias
  45. # Valid state var types.
  46. JSONType = {str, int, float, bool}
  47. PrimitiveType = int | float | bool | str | list | dict | set | tuple
  48. PrimitiveTypes = (int, float, bool, str, list, dict, set, tuple)
  49. StateVar = PrimitiveType | Base | None
  50. StateIterVar = list | set | tuple
  51. if TYPE_CHECKING:
  52. from reflex.vars.base import Var
# Type variables for the positional arguments of argument-spec callables.
# Each is bound to ``Var`` through a forward-reference string, so ``Var``
# itself only needs to be imported under TYPE_CHECKING.
VAR1 = TypeVar("VAR1", bound="Var")
VAR2 = TypeVar("VAR2", bound="Var")
VAR3 = TypeVar("VAR3", bound="Var")
VAR4 = TypeVar("VAR4", bound="Var")
VAR5 = TypeVar("VAR5", bound="Var")
VAR6 = TypeVar("VAR6", bound="Var")
VAR7 = TypeVar("VAR7", bound="Var")


class _ArgsSpec0(Protocol):
    """Callable taking no arguments and returning a sequence of Vars."""

    def __call__(self) -> Sequence[Var]: ...


class _ArgsSpec1(Protocol):
    """Callable taking one positional-only Var and returning a sequence of Vars."""

    def __call__(self, var1: VAR1, /) -> Sequence[Var]: ...  # pyright: ignore [reportInvalidTypeVarUse]


class _ArgsSpec2(Protocol):
    """Callable taking two positional-only Vars and returning a sequence of Vars."""

    def __call__(self, var1: VAR1, var2: VAR2, /) -> Sequence[Var]: ...  # pyright: ignore [reportInvalidTypeVarUse]


class _ArgsSpec3(Protocol):
    """Callable taking three positional-only Vars and returning a sequence of Vars."""

    def __call__(self, var1: VAR1, var2: VAR2, var3: VAR3, /) -> Sequence[Var]: ...  # pyright: ignore [reportInvalidTypeVarUse]


class _ArgsSpec4(Protocol):
    """Callable taking four positional-only Vars and returning a sequence of Vars."""

    def __call__(
        self,
        var1: VAR1,  # pyright: ignore [reportInvalidTypeVarUse]
        var2: VAR2,  # pyright: ignore [reportInvalidTypeVarUse]
        var3: VAR3,  # pyright: ignore [reportInvalidTypeVarUse]
        var4: VAR4,  # pyright: ignore [reportInvalidTypeVarUse]
        /,
    ) -> Sequence[Var]: ...


class _ArgsSpec5(Protocol):
    """Callable taking five positional-only Vars and returning a sequence of Vars."""

    def __call__(
        self,
        var1: VAR1,  # pyright: ignore [reportInvalidTypeVarUse]
        var2: VAR2,  # pyright: ignore [reportInvalidTypeVarUse]
        var3: VAR3,  # pyright: ignore [reportInvalidTypeVarUse]
        var4: VAR4,  # pyright: ignore [reportInvalidTypeVarUse]
        var5: VAR5,  # pyright: ignore [reportInvalidTypeVarUse]
        /,
    ) -> Sequence[Var]: ...


class _ArgsSpec6(Protocol):
    """Callable taking six positional-only Vars and returning a sequence of Vars."""

    def __call__(
        self,
        var1: VAR1,  # pyright: ignore [reportInvalidTypeVarUse]
        var2: VAR2,  # pyright: ignore [reportInvalidTypeVarUse]
        var3: VAR3,  # pyright: ignore [reportInvalidTypeVarUse]
        var4: VAR4,  # pyright: ignore [reportInvalidTypeVarUse]
        var5: VAR5,  # pyright: ignore [reportInvalidTypeVarUse]
        var6: VAR6,  # pyright: ignore [reportInvalidTypeVarUse]
        /,
    ) -> Sequence[Var]: ...


class _ArgsSpec7(Protocol):
    """Callable taking seven positional-only Vars and returning a sequence of Vars."""

    def __call__(
        self,
        var1: VAR1,  # pyright: ignore [reportInvalidTypeVarUse]
        var2: VAR2,  # pyright: ignore [reportInvalidTypeVarUse]
        var3: VAR3,  # pyright: ignore [reportInvalidTypeVarUse]
        var4: VAR4,  # pyright: ignore [reportInvalidTypeVarUse]
        var5: VAR5,  # pyright: ignore [reportInvalidTypeVarUse]
        var6: VAR6,  # pyright: ignore [reportInvalidTypeVarUse]
        var7: VAR7,  # pyright: ignore [reportInvalidTypeVarUse]
        /,
    ) -> Sequence[Var]: ...


# An argument spec: any of the callable shapes above (0 to 7 Var arguments).
ArgsSpec = (
    _ArgsSpec0
    | _ArgsSpec1
    | _ArgsSpec2
    | _ArgsSpec3
    | _ArgsSpec4
    | _ArgsSpec5
    | _ArgsSpec6
    | _ArgsSpec7
)
# ASGI type aliases: connection scope, messages, the receive/send callables and
# an ASGI application callable.
Scope = MutableMapping[str, Any]
Message = MutableMapping[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]

# Maps builtin container types to their ``typing`` counterparts; used when
# building subscripted annotations such as ``List[item_type]`` for array-like
# SQLAlchemy columns.
PrimitiveToAnnotation = {
    list: List,  # noqa: UP006
    tuple: Tuple,  # noqa: UP006
    dict: Dict,  # noqa: UP006
}

# Underscore-prefixed names that must never be treated as backend variables.
RESERVED_BACKEND_VAR_NAMES = {
    "_abc_impl",
    "_backend_vars",
    "_was_touched",
}
  135. class Unset:
  136. """A class to represent an unset value.
  137. This is used to differentiate between a value that is not set and a value that is set to None.
  138. """
  139. def __repr__(self) -> str:
  140. """Return the string representation of the class.
  141. Returns:
  142. The string representation of the class.
  143. """
  144. return "Unset"
  145. def __bool__(self) -> bool:
  146. """Return False when the class is used in a boolean context.
  147. Returns:
  148. False
  149. """
  150. return False
  151. @lru_cache
  152. def _get_origin_cached(tp: Any):
  153. return get_origin_og(tp)
  154. def get_origin(tp: Any):
  155. """Get the origin of a class.
  156. Args:
  157. tp: The class to get the origin of.
  158. Returns:
  159. The origin of the class.
  160. """
  161. return (
  162. origin
  163. if (origin := getattr(tp, "__origin__", None)) is not None
  164. else _get_origin_cached(tp)
  165. )
  166. @lru_cache
  167. def is_generic_alias(cls: GenericType) -> bool:
  168. """Check whether the class is a generic alias.
  169. Args:
  170. cls: The class to check.
  171. Returns:
  172. Whether the class is a generic alias.
  173. """
  174. return isinstance(cls, GenericAliasTypes)
  175. @lru_cache
  176. def get_type_hints(obj: Any) -> dict[str, Any]:
  177. """Get the type hints of a class.
  178. Args:
  179. obj: The class to get the type hints of.
  180. Returns:
  181. The type hints of the class.
  182. """
  183. return get_type_hints_og(obj)
  184. def _unionize(args: list[GenericType]) -> GenericType:
  185. if not args:
  186. return Any # pyright: ignore [reportReturnType]
  187. if len(args) == 1:
  188. return args[0]
  189. return Union[tuple(args)] # noqa: UP007
  190. def unionize(*args: GenericType) -> type:
  191. """Unionize the types.
  192. Args:
  193. args: The types to unionize.
  194. Returns:
  195. The unionized types.
  196. """
  197. return _unionize([arg for arg in args if arg is not NoReturn])
  198. def is_none(cls: GenericType) -> bool:
  199. """Check if a class is None.
  200. Args:
  201. cls: The class to check.
  202. Returns:
  203. Whether the class is None.
  204. """
  205. return cls is type(None) or cls is None
  206. def is_union(cls: GenericType) -> bool:
  207. """Check if a class is a Union.
  208. Args:
  209. cls: The class to check.
  210. Returns:
  211. Whether the class is a Union.
  212. """
  213. origin = getattr(cls, "__origin__", None)
  214. if origin is Union:
  215. return True
  216. return origin is None and isinstance(cls, types.UnionType)
  217. def is_literal(cls: GenericType) -> bool:
  218. """Check if a class is a Literal.
  219. Args:
  220. cls: The class to check.
  221. Returns:
  222. Whether the class is a literal.
  223. """
  224. return getattr(cls, "__origin__", None) is Literal
  225. def has_args(cls: type) -> bool:
  226. """Check if the class has generic parameters.
  227. Args:
  228. cls: The class to check.
  229. Returns:
  230. Whether the class has generic
  231. """
  232. if get_args(cls):
  233. return True
  234. # Check if the class inherits from a generic class (using __orig_bases__)
  235. if hasattr(cls, "__orig_bases__"):
  236. for base in cls.__orig_bases__:
  237. if get_args(base):
  238. return True
  239. return False
  240. def is_optional(cls: GenericType) -> bool:
  241. """Check if a class is an Optional.
  242. Args:
  243. cls: The class to check.
  244. Returns:
  245. Whether the class is an Optional.
  246. """
  247. return is_union(cls) and type(None) in get_args(cls)
  248. def is_classvar(a_type: Any) -> bool:
  249. """Check if a type is a ClassVar.
  250. Args:
  251. a_type: The type to check.
  252. Returns:
  253. Whether the type is a ClassVar.
  254. """
  255. return (
  256. a_type is ClassVar
  257. or (type(a_type) is _GenericAlias and a_type.__origin__ is ClassVar)
  258. or (
  259. type(a_type) is ForwardRef and a_type.__forward_arg__.startswith("ClassVar")
  260. )
  261. )
  262. def true_type_for_pydantic_field(f: ModelField):
  263. """Get the type for a pydantic field.
  264. Args:
  265. f: The field to get the type for.
  266. Returns:
  267. The type for the field.
  268. """
  269. if not isinstance(f.annotation, (str, ForwardRef)):
  270. return f.annotation
  271. type_ = f.outer_type_
  272. if (
  273. f.field_info.default is None
  274. or (isinstance(f.annotation, str) and f.annotation.startswith("Optional"))
  275. or (
  276. isinstance(f.annotation, ForwardRef)
  277. and f.annotation.__forward_arg__.startswith("Optional")
  278. )
  279. ) and not is_optional(type_):
  280. return type_ | None
  281. return type_
  282. def value_inside_optional(cls: GenericType) -> GenericType:
  283. """Get the value inside an Optional type or the original type.
  284. Args:
  285. cls: The class to check.
  286. Returns:
  287. The value inside the Optional type or the original type.
  288. """
  289. if is_union(cls) and len(args := get_args(cls)) >= 2 and type(None) in args:
  290. if len(args) == 2:
  291. return args[0] if args[1] is type(None) else args[1]
  292. return unionize(*[arg for arg in args if arg is not type(None)])
  293. return cls
  294. def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
  295. """Get the type of a field in a class.
  296. Args:
  297. cls: The class to check.
  298. field_name: The name of the field to check.
  299. Returns:
  300. The type of the field, if it exists, else None.
  301. """
  302. if (
  303. hasattr(cls, "__fields__")
  304. and field_name in cls.__fields__
  305. and hasattr(cls.__fields__[field_name], "annotation")
  306. and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
  307. ):
  308. return cls.__fields__[field_name].annotation
  309. type_hints = get_type_hints(cls)
  310. return type_hints.get(field_name, None)
  311. def get_property_hint(attr: Any | None) -> GenericType | None:
  312. """Check if an attribute is a property and return its type hint.
  313. Args:
  314. attr: The descriptor to check.
  315. Returns:
  316. The type hint of the property, if it is a property, else None.
  317. """
  318. from sqlalchemy.ext.hybrid import hybrid_property
  319. if not isinstance(attr, (property, hybrid_property)):
  320. return None
  321. hints = get_type_hints(attr.fget)
  322. return hints.get("return", None)
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
    """Check if an attribute can be accessed on the cls and return its type.

    Supports pydantic models, unions, and annotated attributes on rx.Model.

    Resolution order: property / hybrid_property return hints, then pydantic
    fields, then SQLAlchemy declarative columns / relationships / association
    proxies, then annotations on ``reflex.model.Model`` subclasses, then each
    member of a union, and finally type hints of a bare class.

    Args:
        cls: The class to check.
        name: The name of the attribute to check.

    Returns:
        The type of the attribute, if accessible, or None
    """
    import sqlalchemy
    from sqlalchemy.ext.associationproxy import AssociationProxyInstance
    from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship

    from reflex.model import Model

    # Some descriptors raise NotImplementedError on class-level access; treat
    # that as "no attribute".
    try:
        attr = getattr(cls, name, None)
    except NotImplementedError:
        attr = None

    if hint := get_property_hint(attr):
        return hint

    if hasattr(cls, "__fields__") and name in cls.__fields__:
        # pydantic models
        return get_field_type(cls, name)
    if isinstance(cls, type) and issubclass(cls, DeclarativeBase):
        insp = sqlalchemy.inspect(cls)
        if name in insp.columns:
            # check for list types
            column = insp.columns[name]
            column_type = column.type
            # Column types without a Python equivalent raise NotImplementedError.
            try:
                type_ = insp.columns[name].type.python_type
            except NotImplementedError:
                type_ = None
            if type_ is not None:
                # Array-like column types expose an item_type; build e.g. List[item].
                if hasattr(column_type, "item_type"):
                    try:
                        item_type = column_type.item_type.python_type  # pyright: ignore [reportAttributeAccessIssue]
                    except NotImplementedError:
                        item_type = None
                    if item_type is not None:
                        if type_ in PrimitiveToAnnotation:
                            type_ = PrimitiveToAnnotation[type_]
                        type_ = type_[item_type]  # pyright: ignore [reportIndexIssue]
                if column.nullable:
                    type_ = type_ | None
                return type_
        if name in insp.all_orm_descriptors:
            descriptor = insp.all_orm_descriptors[name]
            if hint := get_property_hint(descriptor):
                return hint
            if isinstance(descriptor, QueryableAttribute):
                prop = descriptor.property
                if isinstance(prop, Relationship):
                    type_ = prop.mapper.class_
                    # TODO: check for nullable?
                    # One-to-many relationships yield a list; scalar ones may be None.
                    return list[type_] if prop.uselist else type_ | None
            # Association proxies resolve to a list of the proxied attribute's type.
            if isinstance(attr, AssociationProxyInstance):
                return list[
                    get_attribute_access_type(
                        attr.target_class,
                        attr.remote_attr.key,  # type: ignore[attr-defined]
                    )
                ]
    elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
        # Check in the annotations directly (for sqlmodel.Relationship)
        hints = get_type_hints(cls)
        if name in hints:
            type_ = hints[name]
            type_origin = get_origin(type_)
            # Unwrap Mapped[T] to T.
            if isinstance(type_origin, type) and issubclass(type_origin, Mapped):
                return get_args(type_)[0]  # SQLAlchemy v2
            if isinstance(type_, ModelField):
                return type_.type_  # SQLAlchemy v1.4
            return type_
    elif is_union(cls):
        # Check in each arg of the annotation.
        # NOTE(review): members lacking the attribute contribute None here, and
        # unionize only drops NoReturn, so the result becomes Optional[...] —
        # confirm this is the intended meaning.
        return unionize(
            *(get_attribute_access_type(arg, name) for arg in get_args(cls))
        )
    elif isinstance(cls, type):
        # Bare class
        # Only NameError (e.g. an unresolvable forward reference) is caught; it
        # is logged and the lookup falls through to None.
        exceptions = NameError
        try:
            hints = get_type_hints(cls)
            if name in hints:
                return hints[name]
        except exceptions as e:
            console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}")
    return None  # Attribute is not accessible.
  411. @lru_cache
  412. def get_base_class(cls: GenericType) -> type:
  413. """Get the base class of a class.
  414. Args:
  415. cls: The class.
  416. Returns:
  417. The base class of the class.
  418. Raises:
  419. TypeError: If a literal has multiple types.
  420. """
  421. if is_literal(cls):
  422. # only literals of the same type are supported.
  423. arg_type = type(get_args(cls)[0])
  424. if not all(type(arg) is arg_type for arg in get_args(cls)):
  425. msg = "only literals of the same type are supported"
  426. raise TypeError(msg)
  427. return type(get_args(cls)[0])
  428. if is_union(cls):
  429. return tuple(get_base_class(arg) for arg in get_args(cls)) # pyright: ignore [reportReturnType]
  430. return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls
  431. def _breakpoints_satisfies_typing(cls_check: GenericType, instance: Any) -> bool:
  432. """Check if the breakpoints instance satisfies the typing.
  433. Args:
  434. cls_check: The class to check against.
  435. instance: The instance to check.
  436. Returns:
  437. Whether the breakpoints instance satisfies the typing.
  438. """
  439. cls_check_base = get_base_class(cls_check)
  440. if cls_check_base == Breakpoints:
  441. _, expected_type = get_args(cls_check)
  442. if is_literal(expected_type):
  443. for value in instance.values():
  444. if not isinstance(value, str) or value not in get_args(expected_type):
  445. return False
  446. return True
  447. if isinstance(cls_check_base, tuple):
  448. # union type, so check all types
  449. return any(
  450. _breakpoints_satisfies_typing(type_to_check, instance)
  451. for type_to_check in get_args(cls_check)
  452. )
  453. if cls_check_base == reflex.vars.Var and "__args__" in cls_check.__dict__:
  454. return _breakpoints_satisfies_typing(get_args(cls_check)[0], instance)
  455. return False
  456. def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None) -> bool:
  457. """Check if a class is a subclass of another class.
  458. Args:
  459. cls: The class to check.
  460. cls_check: The class to check against.
  461. instance: An instance of cls to aid in checking generics.
  462. Returns:
  463. Whether the class is a subclass of the other class.
  464. Raises:
  465. TypeError: If the base class is not valid for issubclass.
  466. """
  467. # Special check for Any.
  468. if cls_check == Any:
  469. return True
  470. if cls in [Any, Callable, None]:
  471. return False
  472. # Get the base classes.
  473. cls_base = get_base_class(cls)
  474. cls_check_base = get_base_class(cls_check)
  475. # The class we're checking should not be a union.
  476. if isinstance(cls_base, tuple):
  477. return False
  478. # Check that fields of breakpoints match the expected values.
  479. if isinstance(instance, Breakpoints):
  480. return _breakpoints_satisfies_typing(cls_check, instance)
  481. if isinstance(cls_check_base, tuple):
  482. cls_check_base = tuple(
  483. cls_check_one if not is_typeddict(cls_check_one) else dict
  484. for cls_check_one in cls_check_base
  485. )
  486. if is_typeddict(cls_check_base):
  487. cls_check_base = dict
  488. # Check if the types match.
  489. try:
  490. return cls_check_base == Any or issubclass(cls_base, cls_check_base)
  491. except TypeError as te:
  492. # These errors typically arise from bad annotations and are hard to
  493. # debug without knowing the type that we tried to compare.
  494. msg = f"Invalid type for issubclass: {cls_base}"
  495. raise TypeError(msg) from te
  496. def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
  497. """Check if an object satisfies a typed dict.
  498. Args:
  499. obj: The object to check.
  500. cls: The typed dict to check against.
  501. Returns:
  502. Whether the object satisfies the typed dict.
  503. """
  504. if not isinstance(obj, Mapping):
  505. return False
  506. key_names_to_values = get_type_hints(cls)
  507. required_keys: frozenset[str] = getattr(cls, "__required_keys__", frozenset())
  508. if not all(
  509. isinstance(key, str)
  510. and key in key_names_to_values
  511. and _isinstance(value, key_names_to_values[key])
  512. for key, value in obj.items()
  513. ):
  514. return False
  515. # TODO in 3.14: Implement https://peps.python.org/pep-0728/ if it's approved
  516. # required keys are all present
  517. return required_keys.issubset(required_keys)
def _isinstance(
    obj: Any,
    cls: GenericType,
    *,
    nested: int = 0,
    treat_var_as_type: bool = True,
    treat_mutable_obj_as_immutable: bool = False,
) -> bool:
    """Check if an object is an instance of a class.

    Unlike the builtin ``isinstance``, ``cls`` may be a typing construct
    (``Any``, ``None``, unions, ``Literal``, TypedDicts, subscripted generics)
    or a ``Var`` type. Var objects are checked by the type they represent.

    Args:
        obj: The object to check.
        cls: The class to check against.
        nested: How many levels deep to check. At ``0`` only the outer
            container of a subscripted generic is checked. Each level of
            element checking consumes one level.
        treat_var_as_type: Whether to treat Var as the type it represents, i.e. _var_type.
        treat_mutable_obj_as_immutable: Whether to treat mutable objects as immutable. Useful if a component declares a mutable object as a prop, but the value is not expected to change.

    Returns:
        Whether the object is an instance of the class.
    """
    # Anything satisfies Any.
    if cls is Any:
        return True

    # Local import (presumably to avoid an import cycle with reflex.vars — confirm).
    from reflex.vars import LiteralVar, Var

    if cls is Var:
        return isinstance(obj, Var)
    # A LiteralVar wraps a concrete Python value: check that value instead.
    if isinstance(obj, LiteralVar):
        return treat_var_as_type and _isinstance(
            obj._var_value, cls, nested=nested, treat_var_as_type=True
        )
    # Other Vars: compare their declared _var_type against cls.
    if isinstance(obj, Var):
        return treat_var_as_type and typehint_issubclass(
            obj._var_type,
            cls,
            treat_mutable_superclasss_as_immutable=treat_mutable_obj_as_immutable,
            treat_literals_as_union_of_types=True,
            treat_any_as_subtype_of_everything=True,
        )

    if cls is None or cls is type(None):
        return obj is None

    # NOTE(review): treat_mutable_obj_as_immutable is not forwarded to the
    # member checks here (nor in the LiteralVar branch above or the nested
    # element checks below) — confirm that is intended.
    if cls is not None and is_union(cls):
        return any(
            _isinstance(obj, arg, nested=nested, treat_var_as_type=treat_var_as_type)
            for arg in get_args(cls)
        )

    if is_literal(cls):
        return obj in get_args(cls)

    origin = get_origin(cls)

    if origin is None:
        # cls is a typed dict
        if is_typeddict(cls):
            if nested:
                return does_obj_satisfy_typed_dict(obj, cls)
            return isinstance(obj, dict)

        # cls is a float
        # ints are accepted where floats are expected.
        if cls is float:
            return isinstance(obj, (float, int))

        # cls is a simple class
        return isinstance(obj, cls)

    args = get_args(cls)

    if not args:
        if treat_mutable_obj_as_immutable:
            if origin is dict:
                origin = Mapping
            # NOTE(review): a set instance is not a Sequence, so with this flag
            # an actual set fails the check for a set annotation — confirm intent.
            elif origin is list or origin is set:
                origin = Sequence
        # cls is a simple generic class
        return isinstance(obj, origin)

    # From here on `args` is non-empty, so the `and args` / `if args:` guards
    # below are always true.
    if origin is Var and args:
        # cls is a Var
        return _isinstance(
            obj,
            args[0],
            nested=nested,
            treat_var_as_type=treat_var_as_type,
            treat_mutable_obj_as_immutable=treat_mutable_obj_as_immutable,
        )

    # Element-wise checks only happen while nesting budget remains.
    if nested > 0 and args:
        if origin is list:
            expected_class = Sequence if treat_mutable_obj_as_immutable else list
            return isinstance(obj, expected_class) and all(
                _isinstance(
                    item,
                    args[0],
                    nested=nested - 1,
                    treat_var_as_type=treat_var_as_type,
                )
                for item in obj
            )
        if origin is tuple:
            # Homogeneous variable-length tuple: tuple[T, ...]
            if args[-1] is Ellipsis:
                return isinstance(obj, tuple) and all(
                    _isinstance(
                        item,
                        args[0],
                        nested=nested - 1,
                        treat_var_as_type=treat_var_as_type,
                    )
                    for item in obj
                )
            # Fixed-length tuple: lengths must match, then check position-wise.
            return (
                isinstance(obj, tuple)
                and len(obj) == len(args)
                and all(
                    _isinstance(
                        item,
                        arg,
                        nested=nested - 1,
                        treat_var_as_type=treat_var_as_type,
                    )
                    for item, arg in zip(obj, args, strict=True)
                )
            )
        if origin in (dict, Mapping, Breakpoints):
            expected_class = (
                dict
                if origin is dict and not treat_mutable_obj_as_immutable
                else Mapping
            )
            return isinstance(obj, expected_class) and all(
                _isinstance(
                    key, args[0], nested=nested - 1, treat_var_as_type=treat_var_as_type
                )
                and _isinstance(
                    value,
                    args[1],
                    nested=nested - 1,
                    treat_var_as_type=treat_var_as_type,
                )
                for key, value in obj.items()
            )
        if origin is set:
            # NOTE(review): with the flag set, actual sets fail this check
            # because a set is not a Sequence — confirm intent.
            expected_class = Sequence if treat_mutable_obj_as_immutable else set
            return isinstance(obj, expected_class) and all(
                _isinstance(
                    item,
                    args[0],
                    nested=nested - 1,
                    treat_var_as_type=treat_var_as_type,
                )
                for item in obj
            )

    if args:
        from reflex.vars import Field

        # Field[T] is checked as T.
        if origin is Field:
            return _isinstance(
                obj, args[0], nested=nested, treat_var_as_type=treat_var_as_type
            )

    # Fallback: only the unsubscripted base class of cls is checked.
    return isinstance(obj, get_base_class(cls))
  664. def is_dataframe(value: type) -> bool:
  665. """Check if the given value is a dataframe.
  666. Args:
  667. value: The value to check.
  668. Returns:
  669. Whether the value is a dataframe.
  670. """
  671. if is_generic_alias(value) or value == Any:
  672. return False
  673. return value.__name__ == "DataFrame"
  674. def is_valid_var_type(type_: type) -> bool:
  675. """Check if the given type is a valid prop type.
  676. Args:
  677. type_: The type to check.
  678. Returns:
  679. Whether the type is a valid prop type.
  680. """
  681. from reflex.utils import serializers
  682. if is_union(type_):
  683. return all(is_valid_var_type(arg) for arg in get_args(type_))
  684. return (
  685. _issubclass(type_, StateVar)
  686. or serializers.has_serializer(type_)
  687. or dataclasses.is_dataclass(type_)
  688. )
  689. def is_backend_base_variable(name: str, cls: type) -> bool:
  690. """Check if this variable name correspond to a backend variable.
  691. Args:
  692. name: The name of the variable to check
  693. cls: The class of the variable to check
  694. Returns:
  695. bool: The result of the check
  696. """
  697. if name in RESERVED_BACKEND_VAR_NAMES:
  698. return False
  699. if not name.startswith("_"):
  700. return False
  701. if name.startswith("__"):
  702. return False
  703. if name.startswith(f"_{cls.__name__}__"):
  704. return False
  705. # Extract the namespace of the original module if defined (dynamic substates).
  706. if callable(getattr(cls, "_get_type_hints", None)):
  707. hints = cls._get_type_hints()
  708. else:
  709. hints = get_type_hints(cls)
  710. if name in hints:
  711. hint = get_origin(hints[name])
  712. if hint == ClassVar:
  713. return False
  714. if name in cls.inherited_backend_vars:
  715. return False
  716. from reflex.vars.base import is_computed_var
  717. if name in cls.__dict__:
  718. value = cls.__dict__[name]
  719. if type(value) is classmethod:
  720. return False
  721. if callable(value):
  722. return False
  723. if isinstance(
  724. value,
  725. (
  726. types.FunctionType,
  727. property,
  728. cached_property,
  729. ),
  730. ) or is_computed_var(value):
  731. return False
  732. return True
  733. def check_type_in_allowed_types(value_type: type, allowed_types: Iterable) -> bool:
  734. """Check that a value type is found in a list of allowed types.
  735. Args:
  736. value_type: Type of value.
  737. allowed_types: Iterable of allowed types.
  738. Returns:
  739. If the type is found in the allowed types.
  740. """
  741. return get_base_class(value_type) in allowed_types
  742. def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
  743. """Check that a prop value is in a list of allowed types.
  744. Does the check in a way that works regardless if it's a raw value or a state Var.
  745. Args:
  746. prop: The prop to check.
  747. allowed_types: The list of allowed types.
  748. Returns:
  749. If the prop type match one of the allowed_types.
  750. """
  751. from reflex.vars import Var
  752. type_ = prop._var_type if isinstance(prop, Var) else type(prop)
  753. return type_ in allowed_types
  754. def is_encoded_fstring(value: Any) -> bool:
  755. """Check if a value is an encoded Var f-string.
  756. Args:
  757. value: The value string to check.
  758. Returns:
  759. Whether the value is an f-string
  760. """
  761. return isinstance(value, str) and constants.REFLEX_VAR_OPENING_TAG in value
  762. def validate_literal(key: str, value: Any, expected_type: type, comp_name: str):
  763. """Check that a value is a valid literal.
  764. Args:
  765. key: The prop name.
  766. value: The prop value to validate.
  767. expected_type: The expected type(literal type).
  768. comp_name: Name of the component.
  769. Raises:
  770. ValueError: When the value is not a valid literal.
  771. """
  772. from reflex.vars import Var
  773. if (
  774. is_literal(expected_type)
  775. and not isinstance(value, Var) # validating vars is not supported yet.
  776. and not is_encoded_fstring(value) # f-strings are not supported.
  777. and value not in expected_type.__args__
  778. ):
  779. allowed_values = expected_type.__args__
  780. if value not in allowed_values:
  781. allowed_value_str = ",".join(
  782. [str(v) if not isinstance(v, str) else f"'{v}'" for v in allowed_values]
  783. )
  784. value_str = f"'{value}'" if isinstance(value, str) else value
  785. msg = f"prop value for {key!s} of the `{comp_name}` component should be one of the following: {allowed_value_str}. Got {value_str} instead"
  786. raise ValueError(msg)
  787. def validate_parameter_literals(func: Callable):
  788. """Decorator to check that the arguments passed to a function
  789. correspond to the correct function parameter if it (the parameter)
  790. is a literal type.
  791. Args:
  792. func: The function to validate.
  793. Returns:
  794. The wrapper function.
  795. """
  796. console.deprecate(
  797. "validate_parameter_literals",
  798. reason="Use manual validation instead.",
  799. deprecation_version="0.7.11",
  800. removal_version="0.8.0",
  801. dedupe=True,
  802. )
  803. func_params = list(inspect.signature(func).parameters.items())
  804. annotations = {param[0]: param[1].annotation for param in func_params}
  805. @wraps(func)
  806. def wrapper(*args, **kwargs):
  807. # validate args
  808. for param, arg in zip(annotations, args, strict=False):
  809. if annotations[param] is inspect.Parameter.empty:
  810. continue
  811. validate_literal(param, arg, annotations[param], func.__name__)
  812. # validate kwargs.
  813. for key, value in kwargs.items():
  814. annotation = annotations.get(key)
  815. if not annotation or annotation is inspect.Parameter.empty:
  816. continue
  817. validate_literal(key, value, annotation, func.__name__)
  818. return func(*args, **kwargs)
  819. return wrapper
  820. # Store this here for performance.
  821. StateBases = get_base_class(StateVar)
  822. StateIterBases = get_base_class(StateIterVar)
  823. def safe_issubclass(cls: Any, cls_check: Any | tuple[Any, ...]):
  824. """Check if a class is a subclass of another class. Returns False if internal error occurs.
  825. Args:
  826. cls: The class to check.
  827. cls_check: The class to check against.
  828. Returns:
  829. Whether the class is a subclass of the other class.
  830. """
  831. try:
  832. return issubclass(cls, cls_check)
  833. except TypeError:
  834. return False
  835. def typehint_issubclass(
  836. possible_subclass: Any,
  837. possible_superclass: Any,
  838. *,
  839. treat_mutable_superclasss_as_immutable: bool = False,
  840. treat_literals_as_union_of_types: bool = True,
  841. treat_any_as_subtype_of_everything: bool = False,
  842. ) -> bool:
  843. """Check if a type hint is a subclass of another type hint.
  844. Args:
  845. possible_subclass: The type hint to check.
  846. possible_superclass: The type hint to check against.
  847. treat_mutable_superclasss_as_immutable: Whether to treat target classes as immutable.
  848. treat_literals_as_union_of_types: Whether to treat literals as a union of their types.
  849. treat_any_as_subtype_of_everything: Whether to treat Any as a subtype of everything. This is the default behavior in Python.
  850. Returns:
  851. Whether the type hint is a subclass of the other type hint.
  852. """
  853. if possible_subclass is possible_superclass or possible_superclass is Any:
  854. return True
  855. if possible_subclass is Any:
  856. return treat_any_as_subtype_of_everything
  857. if possible_subclass is NoReturn:
  858. return True
  859. provided_type_origin = get_origin(possible_subclass)
  860. accepted_type_origin = get_origin(possible_superclass)
  861. if provided_type_origin is None and accepted_type_origin is None:
  862. # In this case, we are dealing with a non-generic type, so we can use issubclass
  863. return issubclass(possible_subclass, possible_superclass)
  864. if treat_literals_as_union_of_types and is_literal(possible_superclass):
  865. args = get_args(possible_superclass)
  866. return any(
  867. typehint_issubclass(
  868. possible_subclass,
  869. type(arg),
  870. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  871. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  872. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  873. )
  874. for arg in args
  875. )
  876. if is_literal(possible_subclass):
  877. args = get_args(possible_subclass)
  878. return all(
  879. _isinstance(
  880. arg,
  881. possible_superclass,
  882. treat_mutable_obj_as_immutable=treat_mutable_superclasss_as_immutable,
  883. nested=2,
  884. )
  885. for arg in args
  886. )
  887. provided_type_origin = (
  888. Union if provided_type_origin is types.UnionType else provided_type_origin
  889. )
  890. accepted_type_origin = (
  891. Union if accepted_type_origin is types.UnionType else accepted_type_origin
  892. )
  893. # Get type arguments (e.g., [float, int] for dict[float, int])
  894. provided_args = get_args(possible_subclass)
  895. accepted_args = get_args(possible_superclass)
  896. if accepted_type_origin is Union:
  897. if provided_type_origin is not Union:
  898. return any(
  899. typehint_issubclass(
  900. possible_subclass,
  901. accepted_arg,
  902. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  903. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  904. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  905. )
  906. for accepted_arg in accepted_args
  907. )
  908. return all(
  909. any(
  910. typehint_issubclass(
  911. provided_arg,
  912. accepted_arg,
  913. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  914. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  915. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  916. )
  917. for accepted_arg in accepted_args
  918. )
  919. for provided_arg in provided_args
  920. )
  921. if provided_type_origin is Union:
  922. return all(
  923. typehint_issubclass(
  924. provided_arg,
  925. possible_superclass,
  926. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  927. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  928. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  929. )
  930. for provided_arg in provided_args
  931. )
  932. provided_type_origin = provided_type_origin or possible_subclass
  933. accepted_type_origin = accepted_type_origin or possible_superclass
  934. if treat_mutable_superclasss_as_immutable:
  935. if accepted_type_origin is dict:
  936. accepted_type_origin = Mapping
  937. elif accepted_type_origin is list or accepted_type_origin is set:
  938. accepted_type_origin = Sequence
  939. # Check if the origin of both types is the same (e.g., list for list[int])
  940. if not safe_issubclass(
  941. provided_type_origin or possible_subclass,
  942. accepted_type_origin or possible_superclass,
  943. ):
  944. return False
  945. # Ensure all specific types are compatible with accepted types
  946. # Note this is not necessarily correct, as it doesn't check against contravariance and covariance
  947. # It also ignores when the length of the arguments is different
  948. return all(
  949. typehint_issubclass(
  950. provided_arg,
  951. accepted_arg,
  952. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  953. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  954. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  955. )
  956. for provided_arg, accepted_arg in zip(
  957. provided_args, accepted_args, strict=False
  958. )
  959. if accepted_arg is not Any
  960. )