types.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080
  1. """Contains custom types and methods to check types."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import inspect
  5. import types
  6. from collections.abc import Callable, Iterable, Mapping, Sequence
  7. from functools import cached_property, lru_cache, wraps
  8. from types import GenericAlias
  9. from typing import ( # noqa: UP035
  10. TYPE_CHECKING,
  11. Any,
  12. ClassVar,
  13. Dict,
  14. ForwardRef,
  15. List,
  16. Literal,
  17. NoReturn,
  18. Tuple,
  19. Union,
  20. _GenericAlias, # pyright: ignore [reportAttributeAccessIssue]
  21. _SpecialGenericAlias, # pyright: ignore [reportAttributeAccessIssue]
  22. get_args,
  23. )
  24. from typing import get_origin as get_origin_og
  25. from typing import get_type_hints as get_type_hints_og
  26. import sqlalchemy
  27. from pydantic.v1.fields import ModelField
  28. from sqlalchemy.ext.associationproxy import AssociationProxyInstance
  29. from sqlalchemy.ext.hybrid import hybrid_property
  30. from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
  31. from typing_extensions import Self as Self
  32. from typing_extensions import is_typeddict
  33. from typing_extensions import override as override
  34. import reflex
  35. from reflex import constants
  36. from reflex.base import Base
  37. from reflex.components.core.breakpoints import Breakpoints
  38. from reflex.utils import console
# Potential GenericAlias types for isinstance checks (see is_generic_alias).
GenericAliasTypes = (_GenericAlias, GenericAlias, _SpecialGenericAlias)
# Potential Union types for isinstance checks: typing.Union and the PEP 604
# `X | Y` form (types.UnionType).
UnionTypes = (Union, types.UnionType)
# Union of generic types.
GenericType = type | _GenericAlias
# Valid state var types.
# Scalar builtins.
JSONType = {str, int, float, bool}
# PrimitiveType is the union form and PrimitiveTypes the tuple form of the
# same builtins; keep the two in sync.
PrimitiveType = int | float | bool | str | list | dict | set | tuple
PrimitiveTypes = (int, float, bool, str, list, dict, set, tuple)
StateVar = PrimitiveType | Base | None
# Container types among the state var types.
StateIterVar = list | set | tuple
if TYPE_CHECKING:
    from reflex.vars.base import Var

    # For static type checkers: a callable taking between zero and seven Var
    # arguments and returning a sequence of Vars.
    ArgsSpec = (
        Callable[[], Sequence[Var]]
        | Callable[[Var], Sequence[Var]]
        | Callable[[Var, Var], Sequence[Var]]
        | Callable[[Var, Var, Var], Sequence[Var]]
        | Callable[[Var, Var, Var, Var], Sequence[Var]]
        | Callable[[Var, Var, Var, Var, Var], Sequence[Var]]
        | Callable[[Var, Var, Var, Var, Var, Var], Sequence[Var]]
        | Callable[[Var, Var, Var, Var, Var, Var, Var], Sequence[Var]]
    )
else:
    # At runtime Var is not imported here, so a looser alias is used.
    ArgsSpec = Callable[..., list[Any]]
# Maps builtin containers to their typing aliases; get_attribute_access_type
# uses it before subscripting a column's python type with its item type.
PrimitiveToAnnotation = {
    list: List,  # noqa: UP006
    tuple: Tuple,  # noqa: UP006
    dict: Dict,  # noqa: UP006
}

# Underscore-prefixed names that is_backend_base_variable never treats as
# backend variables.
RESERVED_BACKEND_VAR_NAMES = {
    "_abc_impl",
    "_backend_vars",
    "_was_touched",
}
  75. class Unset:
  76. """A class to represent an unset value.
  77. This is used to differentiate between a value that is not set and a value that is set to None.
  78. """
  79. def __repr__(self) -> str:
  80. """Return the string representation of the class.
  81. Returns:
  82. The string representation of the class.
  83. """
  84. return "Unset"
  85. def __bool__(self) -> bool:
  86. """Return False when the class is used in a boolean context.
  87. Returns:
  88. False
  89. """
  90. return False
  91. @lru_cache
  92. def get_origin(tp: Any):
  93. """Get the origin of a class.
  94. Args:
  95. tp: The class to get the origin of.
  96. Returns:
  97. The origin of the class.
  98. """
  99. return get_origin_og(tp)
  100. @lru_cache
  101. def is_generic_alias(cls: GenericType) -> bool:
  102. """Check whether the class is a generic alias.
  103. Args:
  104. cls: The class to check.
  105. Returns:
  106. Whether the class is a generic alias.
  107. """
  108. return isinstance(cls, GenericAliasTypes)
  109. @lru_cache
  110. def get_type_hints(obj: Any) -> dict[str, Any]:
  111. """Get the type hints of a class.
  112. Args:
  113. obj: The class to get the type hints of.
  114. Returns:
  115. The type hints of the class.
  116. """
  117. return get_type_hints_og(obj)
  118. def _unionize(args: list[GenericType]) -> type:
  119. if not args:
  120. return Any # pyright: ignore [reportReturnType]
  121. if len(args) == 1:
  122. return args[0]
  123. # We are bisecting the args list here to avoid hitting the recursion limit
  124. # In Python versions >= 3.11, we can simply do `return Union[*args]`
  125. midpoint = len(args) // 2
  126. first_half, second_half = args[:midpoint], args[midpoint:]
  127. return Union[unionize(*first_half), unionize(*second_half)] # pyright: ignore [reportReturnType] # noqa: UP007
  128. def unionize(*args: GenericType) -> type:
  129. """Unionize the types.
  130. Args:
  131. args: The types to unionize.
  132. Returns:
  133. The unionized types.
  134. """
  135. return _unionize([arg for arg in args if arg is not NoReturn])
  136. def is_none(cls: GenericType) -> bool:
  137. """Check if a class is None.
  138. Args:
  139. cls: The class to check.
  140. Returns:
  141. Whether the class is None.
  142. """
  143. return cls is type(None) or cls is None
  144. @lru_cache
  145. def is_union(cls: GenericType) -> bool:
  146. """Check if a class is a Union.
  147. Args:
  148. cls: The class to check.
  149. Returns:
  150. Whether the class is a Union.
  151. """
  152. return get_origin(cls) in UnionTypes
  153. @lru_cache
  154. def is_literal(cls: GenericType) -> bool:
  155. """Check if a class is a Literal.
  156. Args:
  157. cls: The class to check.
  158. Returns:
  159. Whether the class is a literal.
  160. """
  161. return get_origin(cls) is Literal
  162. def has_args(cls: type) -> bool:
  163. """Check if the class has generic parameters.
  164. Args:
  165. cls: The class to check.
  166. Returns:
  167. Whether the class has generic
  168. """
  169. if get_args(cls):
  170. return True
  171. # Check if the class inherits from a generic class (using __orig_bases__)
  172. if hasattr(cls, "__orig_bases__"):
  173. for base in cls.__orig_bases__:
  174. if get_args(base):
  175. return True
  176. return False
  177. def is_optional(cls: GenericType) -> bool:
  178. """Check if a class is an Optional.
  179. Args:
  180. cls: The class to check.
  181. Returns:
  182. Whether the class is an Optional.
  183. """
  184. return is_union(cls) and type(None) in get_args(cls)
  185. def true_type_for_pydantic_field(f: ModelField):
  186. """Get the type for a pydantic field.
  187. Args:
  188. f: The field to get the type for.
  189. Returns:
  190. The type for the field.
  191. """
  192. if not isinstance(f.annotation, (str, ForwardRef)):
  193. return f.annotation
  194. type_ = f.outer_type_
  195. if (
  196. f.field_info.default is None
  197. or (isinstance(f.annotation, str) and f.annotation.startswith("Optional"))
  198. or (
  199. isinstance(f.annotation, ForwardRef)
  200. and f.annotation.__forward_arg__.startswith("Optional")
  201. )
  202. ) and not is_optional(type_):
  203. return type_ | None
  204. return type_
  205. def value_inside_optional(cls: GenericType) -> GenericType:
  206. """Get the value inside an Optional type or the original type.
  207. Args:
  208. cls: The class to check.
  209. Returns:
  210. The value inside the Optional type or the original type.
  211. """
  212. if is_union(cls) and len(args := get_args(cls)) >= 2 and type(None) in args:
  213. if len(args) == 2:
  214. return args[0] if args[1] is type(None) else args[1]
  215. return unionize(*[arg for arg in args if arg is not type(None)])
  216. return cls
  217. def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
  218. """Get the type of a field in a class.
  219. Args:
  220. cls: The class to check.
  221. field_name: The name of the field to check.
  222. Returns:
  223. The type of the field, if it exists, else None.
  224. """
  225. if (
  226. hasattr(cls, "__fields__")
  227. and field_name in cls.__fields__
  228. and hasattr(cls.__fields__[field_name], "annotation")
  229. and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
  230. ):
  231. return cls.__fields__[field_name].annotation
  232. type_hints = get_type_hints(cls)
  233. return type_hints.get(field_name, None)
  234. def get_property_hint(attr: Any | None) -> GenericType | None:
  235. """Check if an attribute is a property and return its type hint.
  236. Args:
  237. attr: The descriptor to check.
  238. Returns:
  239. The type hint of the property, if it is a property, else None.
  240. """
  241. if not isinstance(attr, (property, hybrid_property)):
  242. return None
  243. hints = get_type_hints(attr.fget)
  244. return hints.get("return", None)
  245. def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
  246. """Check if an attribute can be accessed on the cls and return its type.
  247. Supports pydantic models, unions, and annotated attributes on rx.Model.
  248. Args:
  249. cls: The class to check.
  250. name: The name of the attribute to check.
  251. Returns:
  252. The type of the attribute, if accessible, or None
  253. """
  254. from reflex.model import Model
  255. try:
  256. attr = getattr(cls, name, None)
  257. except NotImplementedError:
  258. attr = None
  259. if hint := get_property_hint(attr):
  260. return hint
  261. if hasattr(cls, "__fields__") and name in cls.__fields__:
  262. # pydantic models
  263. return get_field_type(cls, name)
  264. elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
  265. insp = sqlalchemy.inspect(cls)
  266. if name in insp.columns:
  267. # check for list types
  268. column = insp.columns[name]
  269. column_type = column.type
  270. try:
  271. type_ = insp.columns[name].type.python_type
  272. except NotImplementedError:
  273. type_ = None
  274. if type_ is not None:
  275. if hasattr(column_type, "item_type"):
  276. try:
  277. item_type = column_type.item_type.python_type # pyright: ignore [reportAttributeAccessIssue]
  278. except NotImplementedError:
  279. item_type = None
  280. if item_type is not None:
  281. if type_ in PrimitiveToAnnotation:
  282. type_ = PrimitiveToAnnotation[type_]
  283. type_ = type_[item_type] # pyright: ignore [reportIndexIssue]
  284. if column.nullable:
  285. type_ = type_ | None
  286. return type_
  287. if name in insp.all_orm_descriptors:
  288. descriptor = insp.all_orm_descriptors[name]
  289. if hint := get_property_hint(descriptor):
  290. return hint
  291. if isinstance(descriptor, QueryableAttribute):
  292. prop = descriptor.property
  293. if isinstance(prop, Relationship):
  294. type_ = prop.mapper.class_
  295. # TODO: check for nullable?
  296. type_ = list[type_] if prop.uselist else type_ | None
  297. return type_
  298. if isinstance(attr, AssociationProxyInstance):
  299. return list[
  300. get_attribute_access_type(
  301. attr.target_class,
  302. attr.remote_attr.key, # type: ignore[attr-defined]
  303. )
  304. ]
  305. elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
  306. # Check in the annotations directly (for sqlmodel.Relationship)
  307. hints = get_type_hints(cls)
  308. if name in hints:
  309. type_ = hints[name]
  310. type_origin = get_origin(type_)
  311. if isinstance(type_origin, type) and issubclass(type_origin, Mapped):
  312. return get_args(type_)[0] # SQLAlchemy v2
  313. if isinstance(type_, ModelField):
  314. return type_.type_ # SQLAlchemy v1.4
  315. return type_
  316. elif is_union(cls):
  317. # Check in each arg of the annotation.
  318. return unionize(
  319. *(get_attribute_access_type(arg, name) for arg in get_args(cls))
  320. )
  321. elif isinstance(cls, type):
  322. # Bare class
  323. exceptions = NameError
  324. try:
  325. hints = get_type_hints(cls)
  326. if name in hints:
  327. return hints[name]
  328. except exceptions as e:
  329. console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}")
  330. pass
  331. return None # Attribute is not accessible.
  332. @lru_cache
  333. def get_base_class(cls: GenericType) -> type:
  334. """Get the base class of a class.
  335. Args:
  336. cls: The class.
  337. Returns:
  338. The base class of the class.
  339. Raises:
  340. TypeError: If a literal has multiple types.
  341. """
  342. if is_literal(cls):
  343. # only literals of the same type are supported.
  344. arg_type = type(get_args(cls)[0])
  345. if not all(type(arg) is arg_type for arg in get_args(cls)):
  346. raise TypeError("only literals of the same type are supported")
  347. return type(get_args(cls)[0])
  348. if is_union(cls):
  349. return tuple(get_base_class(arg) for arg in get_args(cls)) # pyright: ignore [reportReturnType]
  350. return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls
  351. def _breakpoints_satisfies_typing(cls_check: GenericType, instance: Any) -> bool:
  352. """Check if the breakpoints instance satisfies the typing.
  353. Args:
  354. cls_check: The class to check against.
  355. instance: The instance to check.
  356. Returns:
  357. Whether the breakpoints instance satisfies the typing.
  358. """
  359. cls_check_base = get_base_class(cls_check)
  360. if cls_check_base == Breakpoints:
  361. _, expected_type = get_args(cls_check)
  362. if is_literal(expected_type):
  363. for value in instance.values():
  364. if not isinstance(value, str) or value not in get_args(expected_type):
  365. return False
  366. return True
  367. elif isinstance(cls_check_base, tuple):
  368. # union type, so check all types
  369. return any(
  370. _breakpoints_satisfies_typing(type_to_check, instance)
  371. for type_to_check in get_args(cls_check)
  372. )
  373. elif cls_check_base == reflex.vars.Var and "__args__" in cls_check.__dict__:
  374. return _breakpoints_satisfies_typing(get_args(cls_check)[0], instance)
  375. return False
  376. def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None) -> bool:
  377. """Check if a class is a subclass of another class.
  378. Args:
  379. cls: The class to check.
  380. cls_check: The class to check against.
  381. instance: An instance of cls to aid in checking generics.
  382. Returns:
  383. Whether the class is a subclass of the other class.
  384. Raises:
  385. TypeError: If the base class is not valid for issubclass.
  386. """
  387. # Special check for Any.
  388. if cls_check == Any:
  389. return True
  390. if cls in [Any, Callable, None]:
  391. return False
  392. # Get the base classes.
  393. cls_base = get_base_class(cls)
  394. cls_check_base = get_base_class(cls_check)
  395. # The class we're checking should not be a union.
  396. if isinstance(cls_base, tuple):
  397. return False
  398. # Check that fields of breakpoints match the expected values.
  399. if isinstance(instance, Breakpoints):
  400. return _breakpoints_satisfies_typing(cls_check, instance)
  401. if isinstance(cls_check_base, tuple):
  402. cls_check_base = tuple(
  403. cls_check_one if not is_typeddict(cls_check_one) else dict
  404. for cls_check_one in cls_check_base
  405. )
  406. if is_typeddict(cls_check_base):
  407. cls_check_base = dict
  408. # Check if the types match.
  409. try:
  410. return cls_check_base == Any or issubclass(cls_base, cls_check_base)
  411. except TypeError as te:
  412. # These errors typically arise from bad annotations and are hard to
  413. # debug without knowing the type that we tried to compare.
  414. raise TypeError(f"Invalid type for issubclass: {cls_base}") from te
  415. def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
  416. """Check if an object satisfies a typed dict.
  417. Args:
  418. obj: The object to check.
  419. cls: The typed dict to check against.
  420. Returns:
  421. Whether the object satisfies the typed dict.
  422. """
  423. if not isinstance(obj, Mapping):
  424. return False
  425. key_names_to_values = get_type_hints(cls)
  426. required_keys: frozenset[str] = getattr(cls, "__required_keys__", frozenset())
  427. if not all(
  428. isinstance(key, str)
  429. and key in key_names_to_values
  430. and _isinstance(value, key_names_to_values[key])
  431. for key, value in obj.items()
  432. ):
  433. return False
  434. # TODO in 3.14: Implement https://peps.python.org/pep-0728/ if it's approved
  435. # required keys are all present
  436. return required_keys.issubset(required_keys)
  437. def _isinstance(
  438. obj: Any,
  439. cls: GenericType,
  440. *,
  441. nested: int = 0,
  442. treat_var_as_type: bool = True,
  443. treat_mutable_obj_as_immutable: bool = False,
  444. ) -> bool:
  445. """Check if an object is an instance of a class.
  446. Args:
  447. obj: The object to check.
  448. cls: The class to check against.
  449. nested: How many levels deep to check.
  450. treat_var_as_type: Whether to treat Var as the type it represents, i.e. _var_type.
  451. treat_mutable_obj_as_immutable: Whether to treat mutable objects as immutable. Useful if a component declares a mutable object as a prop, but the value is not expected to change.
  452. Returns:
  453. Whether the object is an instance of the class.
  454. """
  455. if cls is Any:
  456. return True
  457. from reflex.vars import LiteralVar, Var
  458. if cls is Var:
  459. return isinstance(obj, Var)
  460. if isinstance(obj, LiteralVar):
  461. return treat_var_as_type and _isinstance(
  462. obj._var_value, cls, nested=nested, treat_var_as_type=True
  463. )
  464. if isinstance(obj, Var):
  465. return treat_var_as_type and typehint_issubclass(
  466. obj._var_type,
  467. cls,
  468. treat_mutable_superclasss_as_immutable=treat_mutable_obj_as_immutable,
  469. treat_literals_as_union_of_types=True,
  470. treat_any_as_subtype_of_everything=True,
  471. )
  472. if cls is None or cls is type(None):
  473. return obj is None
  474. if cls is not None and is_union(cls):
  475. return any(
  476. _isinstance(obj, arg, nested=nested, treat_var_as_type=treat_var_as_type)
  477. for arg in get_args(cls)
  478. )
  479. if is_literal(cls):
  480. return obj in get_args(cls)
  481. origin = get_origin(cls)
  482. if origin is None:
  483. # cls is a typed dict
  484. if is_typeddict(cls):
  485. if nested:
  486. return does_obj_satisfy_typed_dict(obj, cls)
  487. return isinstance(obj, dict)
  488. # cls is a float
  489. if cls is float:
  490. return isinstance(obj, (float, int))
  491. # cls is a simple class
  492. return isinstance(obj, cls)
  493. args = get_args(cls)
  494. if not args:
  495. if treat_mutable_obj_as_immutable:
  496. if origin is dict:
  497. origin = Mapping
  498. elif origin is list or origin is set:
  499. origin = Sequence
  500. # cls is a simple generic class
  501. return isinstance(obj, origin)
  502. if origin is Var and args:
  503. # cls is a Var
  504. return _isinstance(
  505. obj,
  506. args[0],
  507. nested=nested,
  508. treat_var_as_type=treat_var_as_type,
  509. treat_mutable_obj_as_immutable=treat_mutable_obj_as_immutable,
  510. )
  511. if nested > 0 and args:
  512. if origin is list:
  513. expected_class = Sequence if treat_mutable_obj_as_immutable else list
  514. return isinstance(obj, expected_class) and all(
  515. _isinstance(
  516. item,
  517. args[0],
  518. nested=nested - 1,
  519. treat_var_as_type=treat_var_as_type,
  520. )
  521. for item in obj
  522. )
  523. if origin is tuple:
  524. if args[-1] is Ellipsis:
  525. return isinstance(obj, tuple) and all(
  526. _isinstance(
  527. item,
  528. args[0],
  529. nested=nested - 1,
  530. treat_var_as_type=treat_var_as_type,
  531. )
  532. for item in obj
  533. )
  534. return (
  535. isinstance(obj, tuple)
  536. and len(obj) == len(args)
  537. and all(
  538. _isinstance(
  539. item,
  540. arg,
  541. nested=nested - 1,
  542. treat_var_as_type=treat_var_as_type,
  543. )
  544. for item, arg in zip(obj, args, strict=True)
  545. )
  546. )
  547. if origin in (dict, Mapping, Breakpoints):
  548. expected_class = (
  549. dict
  550. if origin is dict and not treat_mutable_obj_as_immutable
  551. else Mapping
  552. )
  553. return isinstance(obj, expected_class) and all(
  554. _isinstance(
  555. key, args[0], nested=nested - 1, treat_var_as_type=treat_var_as_type
  556. )
  557. and _isinstance(
  558. value,
  559. args[1],
  560. nested=nested - 1,
  561. treat_var_as_type=treat_var_as_type,
  562. )
  563. for key, value in obj.items()
  564. )
  565. if origin is set:
  566. expected_class = Sequence if treat_mutable_obj_as_immutable else set
  567. return isinstance(obj, expected_class) and all(
  568. _isinstance(
  569. item,
  570. args[0],
  571. nested=nested - 1,
  572. treat_var_as_type=treat_var_as_type,
  573. )
  574. for item in obj
  575. )
  576. if args:
  577. from reflex.vars import Field
  578. if origin is Field:
  579. return _isinstance(
  580. obj, args[0], nested=nested, treat_var_as_type=treat_var_as_type
  581. )
  582. return isinstance(obj, get_base_class(cls))
  583. def is_dataframe(value: type) -> bool:
  584. """Check if the given value is a dataframe.
  585. Args:
  586. value: The value to check.
  587. Returns:
  588. Whether the value is a dataframe.
  589. """
  590. if is_generic_alias(value) or value == Any:
  591. return False
  592. return value.__name__ == "DataFrame"
  593. def is_valid_var_type(type_: type) -> bool:
  594. """Check if the given type is a valid prop type.
  595. Args:
  596. type_: The type to check.
  597. Returns:
  598. Whether the type is a valid prop type.
  599. """
  600. from reflex.utils import serializers
  601. if is_union(type_):
  602. return all(is_valid_var_type(arg) for arg in get_args(type_))
  603. return (
  604. _issubclass(type_, StateVar)
  605. or serializers.has_serializer(type_)
  606. or dataclasses.is_dataclass(type_)
  607. )
  608. def is_backend_base_variable(name: str, cls: type) -> bool:
  609. """Check if this variable name correspond to a backend variable.
  610. Args:
  611. name: The name of the variable to check
  612. cls: The class of the variable to check
  613. Returns:
  614. bool: The result of the check
  615. """
  616. if name in RESERVED_BACKEND_VAR_NAMES:
  617. return False
  618. if not name.startswith("_"):
  619. return False
  620. if name.startswith("__"):
  621. return False
  622. if name.startswith(f"_{cls.__name__}__"):
  623. return False
  624. # Extract the namespace of the original module if defined (dynamic substates).
  625. if callable(getattr(cls, "_get_type_hints", None)):
  626. hints = cls._get_type_hints()
  627. else:
  628. hints = get_type_hints(cls)
  629. if name in hints:
  630. hint = get_origin(hints[name])
  631. if hint == ClassVar:
  632. return False
  633. if name in cls.inherited_backend_vars:
  634. return False
  635. from reflex.vars.base import is_computed_var
  636. if name in cls.__dict__:
  637. value = cls.__dict__[name]
  638. if type(value) is classmethod:
  639. return False
  640. if callable(value):
  641. return False
  642. if isinstance(
  643. value,
  644. (
  645. types.FunctionType,
  646. property,
  647. cached_property,
  648. ),
  649. ) or is_computed_var(value):
  650. return False
  651. return True
  652. def check_type_in_allowed_types(value_type: type, allowed_types: Iterable) -> bool:
  653. """Check that a value type is found in a list of allowed types.
  654. Args:
  655. value_type: Type of value.
  656. allowed_types: Iterable of allowed types.
  657. Returns:
  658. If the type is found in the allowed types.
  659. """
  660. return get_base_class(value_type) in allowed_types
  661. def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
  662. """Check that a prop value is in a list of allowed types.
  663. Does the check in a way that works regardless if it's a raw value or a state Var.
  664. Args:
  665. prop: The prop to check.
  666. allowed_types: The list of allowed types.
  667. Returns:
  668. If the prop type match one of the allowed_types.
  669. """
  670. from reflex.vars import Var
  671. type_ = prop._var_type if isinstance(prop, Var) else type(prop)
  672. return type_ in allowed_types
  673. def is_encoded_fstring(value: Any) -> bool:
  674. """Check if a value is an encoded Var f-string.
  675. Args:
  676. value: The value string to check.
  677. Returns:
  678. Whether the value is an f-string
  679. """
  680. return isinstance(value, str) and constants.REFLEX_VAR_OPENING_TAG in value
  681. def validate_literal(key: str, value: Any, expected_type: type, comp_name: str):
  682. """Check that a value is a valid literal.
  683. Args:
  684. key: The prop name.
  685. value: The prop value to validate.
  686. expected_type: The expected type(literal type).
  687. comp_name: Name of the component.
  688. Raises:
  689. ValueError: When the value is not a valid literal.
  690. """
  691. from reflex.vars import Var
  692. if (
  693. is_literal(expected_type)
  694. and not isinstance(value, Var) # validating vars is not supported yet.
  695. and not is_encoded_fstring(value) # f-strings are not supported.
  696. and value not in expected_type.__args__
  697. ):
  698. allowed_values = expected_type.__args__
  699. if value not in allowed_values:
  700. allowed_value_str = ",".join(
  701. [str(v) if not isinstance(v, str) else f"'{v}'" for v in allowed_values]
  702. )
  703. value_str = f"'{value}'" if isinstance(value, str) else value
  704. raise ValueError(
  705. f"prop value for {key!s} of the `{comp_name}` component should be one of the following: {allowed_value_str}. Got {value_str} instead"
  706. )
  707. def validate_parameter_literals(func: Callable):
  708. """Decorator to check that the arguments passed to a function
  709. correspond to the correct function parameter if it (the parameter)
  710. is a literal type.
  711. Args:
  712. func: The function to validate.
  713. Returns:
  714. The wrapper function.
  715. """
  716. @wraps(func)
  717. def wrapper(*args, **kwargs):
  718. func_params = list(inspect.signature(func).parameters.items())
  719. annotations = {param[0]: param[1].annotation for param in func_params}
  720. # validate args
  721. for param, arg in zip(annotations, args, strict=False):
  722. if annotations[param] is inspect.Parameter.empty:
  723. continue
  724. validate_literal(param, arg, annotations[param], func.__name__)
  725. # validate kwargs.
  726. for key, value in kwargs.items():
  727. annotation = annotations.get(key)
  728. if not annotation or annotation is inspect.Parameter.empty:
  729. continue
  730. validate_literal(key, value, annotation, func.__name__)
  731. return func(*args, **kwargs)
  732. return wrapper
# Store this here for performance: the base classes of the state var unions
# are computed once at import time. Since StateVar and StateIterVar are
# unions, get_base_class returns a tuple of classes for each.
StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar)
  736. def safe_issubclass(cls: Any, cls_check: Any | tuple[Any, ...]):
  737. """Check if a class is a subclass of another class. Returns False if internal error occurs.
  738. Args:
  739. cls: The class to check.
  740. cls_check: The class to check against.
  741. Returns:
  742. Whether the class is a subclass of the other class.
  743. """
  744. try:
  745. return issubclass(cls, cls_check)
  746. except TypeError:
  747. return False
  748. def typehint_issubclass(
  749. possible_subclass: Any,
  750. possible_superclass: Any,
  751. *,
  752. treat_mutable_superclasss_as_immutable: bool = False,
  753. treat_literals_as_union_of_types: bool = True,
  754. treat_any_as_subtype_of_everything: bool = False,
  755. ) -> bool:
  756. """Check if a type hint is a subclass of another type hint.
  757. Args:
  758. possible_subclass: The type hint to check.
  759. possible_superclass: The type hint to check against.
  760. treat_mutable_superclasss_as_immutable: Whether to treat target classes as immutable.
  761. treat_literals_as_union_of_types: Whether to treat literals as a union of their types.
  762. treat_any_as_subtype_of_everything: Whether to treat Any as a subtype of everything. This is the default behavior in Python.
  763. Returns:
  764. Whether the type hint is a subclass of the other type hint.
  765. """
  766. if possible_superclass is Any:
  767. return True
  768. if possible_subclass is Any:
  769. return treat_any_as_subtype_of_everything
  770. if possible_subclass is NoReturn:
  771. return True
  772. provided_type_origin = get_origin(possible_subclass)
  773. accepted_type_origin = get_origin(possible_superclass)
  774. if provided_type_origin is None and accepted_type_origin is None:
  775. # In this case, we are dealing with a non-generic type, so we can use issubclass
  776. return issubclass(possible_subclass, possible_superclass)
  777. if treat_literals_as_union_of_types and is_literal(possible_superclass):
  778. args = get_args(possible_superclass)
  779. return any(
  780. typehint_issubclass(
  781. possible_subclass,
  782. type(arg),
  783. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  784. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  785. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  786. )
  787. for arg in args
  788. )
  789. if is_literal(possible_subclass):
  790. args = get_args(possible_subclass)
  791. return all(
  792. _isinstance(
  793. arg,
  794. possible_superclass,
  795. treat_mutable_obj_as_immutable=treat_mutable_superclasss_as_immutable,
  796. nested=2,
  797. )
  798. for arg in args
  799. )
  800. provided_type_origin = (
  801. Union if provided_type_origin is types.UnionType else provided_type_origin
  802. )
  803. accepted_type_origin = (
  804. Union if accepted_type_origin is types.UnionType else accepted_type_origin
  805. )
  806. # Get type arguments (e.g., [float, int] for dict[float, int])
  807. provided_args = get_args(possible_subclass)
  808. accepted_args = get_args(possible_superclass)
  809. if accepted_type_origin is Union:
  810. if provided_type_origin is not Union:
  811. return any(
  812. typehint_issubclass(
  813. possible_subclass,
  814. accepted_arg,
  815. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  816. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  817. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  818. )
  819. for accepted_arg in accepted_args
  820. )
  821. return all(
  822. any(
  823. typehint_issubclass(
  824. provided_arg,
  825. accepted_arg,
  826. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  827. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  828. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  829. )
  830. for accepted_arg in accepted_args
  831. )
  832. for provided_arg in provided_args
  833. )
  834. if provided_type_origin is Union:
  835. return all(
  836. typehint_issubclass(
  837. provided_arg,
  838. possible_superclass,
  839. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  840. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  841. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  842. )
  843. for provided_arg in provided_args
  844. )
  845. provided_type_origin = provided_type_origin or possible_subclass
  846. accepted_type_origin = accepted_type_origin or possible_superclass
  847. if treat_mutable_superclasss_as_immutable:
  848. if accepted_type_origin is dict:
  849. accepted_type_origin = Mapping
  850. elif accepted_type_origin is list or accepted_type_origin is set:
  851. accepted_type_origin = Sequence
  852. # Check if the origin of both types is the same (e.g., list for list[int])
  853. if not safe_issubclass(
  854. provided_type_origin or possible_subclass,
  855. accepted_type_origin or possible_superclass,
  856. ):
  857. return False
  858. # Ensure all specific types are compatible with accepted types
  859. # Note this is not necessarily correct, as it doesn't check against contravariance and covariance
  860. # It also ignores when the length of the arguments is different
  861. return all(
  862. typehint_issubclass(
  863. provided_arg,
  864. accepted_arg,
  865. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  866. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  867. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  868. )
  869. for provided_arg, accepted_arg in zip(
  870. provided_args, accepted_args, strict=False
  871. )
  872. if accepted_arg is not Any
  873. )