1
0

types.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103
  1. """Contains custom types and methods to check types."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import inspect
  5. import types
  6. from collections.abc import Callable, Iterable, Mapping, Sequence
  7. from functools import cached_property, lru_cache, wraps
  8. from types import GenericAlias
  9. from typing import ( # noqa: UP035
  10. TYPE_CHECKING,
  11. Any,
  12. Awaitable,
  13. ClassVar,
  14. Dict,
  15. ForwardRef,
  16. List,
  17. Literal,
  18. MutableMapping,
  19. NoReturn,
  20. Tuple,
  21. Union,
  22. _GenericAlias, # pyright: ignore [reportAttributeAccessIssue]
  23. _SpecialGenericAlias, # pyright: ignore [reportAttributeAccessIssue]
  24. get_args,
  25. is_typeddict,
  26. )
  27. from typing import get_origin as get_origin_og
  28. from typing import get_type_hints as get_type_hints_og
  29. from pydantic.v1.fields import ModelField
  30. from typing_extensions import Self as Self
  31. from typing_extensions import override as override
  32. import reflex
  33. from reflex import constants
  34. from reflex.base import Base
  35. from reflex.components.core.breakpoints import Breakpoints
  36. from reflex.utils import console
  # Potential GenericAlias types for isinstance checks.
  GenericAliasTypes = (_GenericAlias, GenericAlias, _SpecialGenericAlias)

  # Potential Union types for isinstance checks.
  UnionTypes = (Union, types.UnionType)

  # Union of generic types.
  GenericType = type | _GenericAlias

  # Valid state var types.
  JSONType = {str, int, float, bool}
  PrimitiveType = int | float | bool | str | list | dict | set | tuple
  # Same members as PrimitiveType, but as a tuple of classes.
  PrimitiveTypes = (int, float, bool, str, list, dict, set, tuple)
  StateVar = PrimitiveType | Base | None
  StateIterVar = list | set | tuple

  if TYPE_CHECKING:
      from reflex.vars.base import Var

      # For static type checkers: a callable taking zero to seven Vars and
      # returning a sequence of Vars.
      ArgsSpec = (
          Callable[[], Sequence[Var]]
          | Callable[[Var], Sequence[Var]]
          | Callable[[Var, Var], Sequence[Var]]
          | Callable[[Var, Var, Var], Sequence[Var]]
          | Callable[[Var, Var, Var, Var], Sequence[Var]]
          | Callable[[Var, Var, Var, Var, Var], Sequence[Var]]
          | Callable[[Var, Var, Var, Var, Var, Var], Sequence[Var]]
          | Callable[[Var, Var, Var, Var, Var, Var, Var], Sequence[Var]]
      )
  else:
      # At runtime a loose alias is used instead of the Var-based union above.
      ArgsSpec = Callable[..., list[Any]]

  # Aliases describing an ASGI-style application callable and its arguments.
  Scope = MutableMapping[str, Any]
  Message = MutableMapping[str, Any]
  Receive = Callable[[], Awaitable[Message]]
  Send = Callable[[Message], Awaitable[None]]

  ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]

  # Maps builtin container classes to their subscriptable typing counterparts.
  PrimitiveToAnnotation = {
      list: List,  # noqa: UP006
      tuple: Tuple,  # noqa: UP006
      dict: Dict,  # noqa: UP006
  }

  # Names that is_backend_base_variable never reports as backend variables.
  RESERVED_BACKEND_VAR_NAMES = {
      "_abc_impl",
      "_backend_vars",
      "_was_touched",
  }
  class Unset:
      """Sentinel type standing for "no value was provided".

      Lets callers distinguish a value that was never set from a value that
      was explicitly set to None.
      """

      def __bool__(self) -> bool:
          """Make every instance falsy in a boolean context.

          Returns:
              Always False.
          """
          return False

      def __repr__(self) -> str:
          """Give the fixed textual form of the sentinel.

          Returns:
              The literal string "Unset".
          """
          return "Unset"
  @lru_cache
  def _get_origin_cached(tp: Any):
      # Memoized pass-through to typing.get_origin (imported as get_origin_og).
      resolved_origin = get_origin_og(tp)
      return resolved_origin
  def get_origin(tp: Any):
      """Resolve the unsubscripted origin of a type hint.

      Args:
          tp: The class to get the origin of.

      Returns:
          The value of ``tp.__origin__`` when that attribute exists and is not
          None, otherwise whatever the cached typing lookup returns.
      """
      direct_origin = getattr(tp, "__origin__", None)
      if direct_origin is not None:
          # Fast path: the hint already exposes its origin.
          return direct_origin
      return _get_origin_cached(tp)
  @lru_cache
  def is_generic_alias(cls: GenericType) -> bool:
      """Tell whether the given object is a generic alias.

      Args:
          cls: The class to check.

      Returns:
          True when ``cls`` is an instance of one of ``GenericAliasTypes``.
      """
      alias_match = isinstance(cls, GenericAliasTypes)
      return alias_match
  @lru_cache
  def get_type_hints(obj: Any) -> dict[str, Any]:
      """Resolve the type hints of an object, memoizing the result.

      Args:
          obj: The class to get the type hints of.

      Returns:
          The mapping produced by ``typing.get_type_hints`` for ``obj``. The
          same dict object is handed back on repeated calls with the same
          argument because the result is cached.
      """
      resolved_hints = get_type_hints_og(obj)
      return resolved_hints
  def _unionize(args: list[GenericType]) -> GenericType:
      # No members -> Any; one member -> that member; otherwise a typing.Union.
      if args:
          return args[0] if len(args) == 1 else Union[tuple(args)]  # noqa: UP007
      return Any  # pyright: ignore [reportReturnType]
  def unionize(*args: GenericType) -> type:
      """Combine the given types into a single union.

      ``NoReturn`` entries are dropped before the union is built.

      Args:
          args: The types to unionize.

      Returns:
          The unionized types.
      """
      kept_types = list(filter(lambda candidate: candidate is not NoReturn, args))
      return _unionize(kept_types)
  def is_none(cls: GenericType) -> bool:
      """Tell whether the argument denotes None.

      Args:
          cls: The class to check.

      Returns:
          True when ``cls`` is the None singleton or the NoneType class.
      """
      none_type = type(None)
      return cls is None or cls is none_type
  149. def is_union(cls: GenericType) -> bool:
  150. """Check if a class is a Union.
  151. Args:
  152. cls: The class to check.
  153. Returns:
  154. Whether the class is a Union.
  155. """
  156. origin = getattr(cls, "__origin__", None)
  157. if origin is Union:
  158. return True
  159. return origin is None and isinstance(cls, types.UnionType)
  def is_literal(cls: GenericType) -> bool:
      """Tell whether a type hint is a ``typing.Literal``.

      Args:
          cls: The class to check.

      Returns:
          True when the hint's ``__origin__`` attribute is ``Literal``.
      """
      hint_origin = getattr(cls, "__origin__", None)
      return hint_origin is Literal
  def has_args(cls: type) -> bool:
      """Tell whether a class carries generic parameters.

      Args:
          cls: The class to check.

      Returns:
          True when ``cls`` itself has type arguments, or when any entry of
          its ``__orig_bases__`` (if present) has type arguments.
      """
      if get_args(cls):
          return True

      # Classes inheriting from a parameterized generic record it here.
      original_bases = getattr(cls, "__orig_bases__", ())
      return any(get_args(base) for base in original_bases)
  def is_optional(cls: GenericType) -> bool:
      """Tell whether a type hint is a union that includes None.

      Args:
          cls: The class to check.

      Returns:
          Whether the class is an Optional.
      """
      if not is_union(cls):
          return False
      return type(None) in get_args(cls)
  def true_type_for_pydantic_field(f: ModelField):
      """Work out the effective type of a pydantic field.

      When the field's annotation is already a real type (neither a string nor
      a ForwardRef) it is returned untouched. Otherwise the field's
      ``outer_type_`` is used, widened with ``None`` when the field defaults to
      None or its textual annotation starts with "Optional" and the outer type
      is not already optional.

      Args:
          f: The field to get the type for.

      Returns:
          The type for the field.
      """
      annotation = f.annotation
      if not isinstance(annotation, (str, ForwardRef)):
          return annotation

      outer_type = f.outer_type_

      # Both unresolved forms expose the annotation source text.
      if isinstance(annotation, ForwardRef):
          annotation_text = annotation.__forward_arg__
      else:
          annotation_text = annotation

      declared_optional = f.field_info.default is None or annotation_text.startswith(
          "Optional"
      )
      if declared_optional and not is_optional(outer_type):
          return outer_type | None
      return outer_type
  def value_inside_optional(cls: GenericType) -> GenericType:
      """Strip ``None`` from an optional type hint.

      Args:
          cls: The class to check.

      Returns:
          The value inside the Optional type or the original type.
      """
      if not is_union(cls):
          return cls

      members = get_args(cls)
      none_type = type(None)
      if len(members) < 2 or none_type not in members:
          return cls

      if len(members) == 2:
          # Plain Optional[X]: hand back whichever member is not NoneType.
          first, second = members
          return first if second is none_type else second

      # Larger union: rebuild it without the NoneType member.
      return unionize(*[member for member in members if member is not none_type])
  def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
      """Look up the declared type of a field on a class.

      Args:
          cls: The class to check.
          field_name: The name of the field to check.

      Returns:
          The type of the field, if it exists, else None.
      """
      if hasattr(cls, "__fields__"):
          declared_fields = cls.__fields__
          if field_name in declared_fields:
              field = declared_fields[field_name]
              if hasattr(field, "annotation"):
                  annotation = field.annotation
                  # Only trust annotations that are already resolved types.
                  if not isinstance(annotation, (str, ForwardRef)):
                      return annotation

      # Fall back to the resolved type hints of the class.
      return get_type_hints(cls).get(field_name, None)
  def get_property_hint(attr: Any | None) -> GenericType | None:
      """Return the annotated return type of a property-like descriptor.

      Args:
          attr: The descriptor to check.

      Returns:
          The "return" type hint of the descriptor's getter when ``attr`` is a
          ``property`` or a SQLAlchemy ``hybrid_property``; None otherwise or
          when the getter has no return annotation.
      """
      from sqlalchemy.ext.hybrid import hybrid_property

      if isinstance(attr, (property, hybrid_property)):
          getter_hints = get_type_hints(attr.fget)
          return getter_hints.get("return", None)
      return None
  def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
      """Check if an attribute can be accessed on the cls and return its type.

      Supports pydantic models, unions, and annotated attributes on rx.Model.

      The lookup tries, in order: a property / hybrid property on the class,
      pydantic-style ``__fields__``, SQLAlchemy declarative columns and ORM
      descriptors, type hints of ``Model`` subclasses, each member of a union,
      and finally the type hints of a bare class.

      Args:
          cls: The class to check.
          name: The name of the attribute to check.

      Returns:
          The type of the attribute, if accessible, or None
      """
      import sqlalchemy
      from sqlalchemy.ext.associationproxy import AssociationProxyInstance
      from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship

      from reflex.model import Model

      # Attribute access may raise NotImplementedError; treat that as "no attribute".
      try:
          attr = getattr(cls, name, None)
      except NotImplementedError:
          attr = None

      if hint := get_property_hint(attr):
          return hint

      if hasattr(cls, "__fields__") and name in cls.__fields__:
          # pydantic models
          return get_field_type(cls, name)
      elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
          insp = sqlalchemy.inspect(cls)
          if name in insp.columns:
              # check for list types
              column = insp.columns[name]
              column_type = column.type
              # python_type may be unimplemented for a column type.
              try:
                  type_ = insp.columns[name].type.python_type
              except NotImplementedError:
                  type_ = None
              if type_ is not None:
                  if hasattr(column_type, "item_type"):
                      try:
                          item_type = column_type.item_type.python_type  # pyright: ignore [reportAttributeAccessIssue]
                      except NotImplementedError:
                          item_type = None
                      if item_type is not None:
                          # Swap builtin containers for their typing aliases
                          # before parameterizing with the item type.
                          if type_ in PrimitiveToAnnotation:
                              type_ = PrimitiveToAnnotation[type_]
                          type_ = type_[item_type]  # pyright: ignore [reportIndexIssue]
                  if column.nullable:
                      type_ = type_ | None
                  return type_
          if name in insp.all_orm_descriptors:
              descriptor = insp.all_orm_descriptors[name]
              if hint := get_property_hint(descriptor):
                  return hint
              if isinstance(descriptor, QueryableAttribute):
                  prop = descriptor.property
                  if isinstance(prop, Relationship):
                      type_ = prop.mapper.class_
                      # TODO: check for nullable?
                      type_ = list[type_] if prop.uselist else type_ | None
                      return type_
              if isinstance(attr, AssociationProxyInstance):
                  # Reported as a list of the proxied remote attribute's type.
                  return list[
                      get_attribute_access_type(
                          attr.target_class,
                          attr.remote_attr.key,  # type: ignore[attr-defined]
                      )
                  ]
      elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
          # Check in the annotations directly (for sqlmodel.Relationship)
          hints = get_type_hints(cls)
          if name in hints:
              type_ = hints[name]
              type_origin = get_origin(type_)
              if isinstance(type_origin, type) and issubclass(type_origin, Mapped):
                  return get_args(type_)[0]  # SQLAlchemy v2
              if isinstance(type_, ModelField):
                  return type_.type_  # SQLAlchemy v1.4
              return type_
      elif is_union(cls):
          # Check in each arg of the annotation.
          return unionize(
              *(get_attribute_access_type(arg, name) for arg in get_args(cls))
          )
      elif isinstance(cls, type):
          # Bare class
          exceptions = NameError
          try:
              hints = get_type_hints(cls)
              if name in hints:
                  return hints[name]
          except exceptions as e:
              # Unresolvable forward references are reported, not raised.
              console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}")
              pass
      return None  # Attribute is not accessible.
  @lru_cache
  def get_base_class(cls: GenericType) -> type:
      """Reduce a type hint to the plain class (or classes) behind it.

      Args:
          cls: The class.

      Returns:
          For a Literal, the shared type of its values; for a union, a tuple
          with the base class of every member; for a generic alias, the base
          class of its origin; otherwise ``cls`` itself.

      Raises:
          TypeError: If a literal has multiple types.
      """
      if is_literal(cls):
          # only literals of the same type are supported.
          literal_values = get_args(cls)
          shared_type = type(literal_values[0])
          if any(type(value) is not shared_type for value in literal_values):
              raise TypeError("only literals of the same type are supported")
          return shared_type

      if is_union(cls):
          return tuple(map(get_base_class, get_args(cls)))  # pyright: ignore [reportReturnType]

      if is_generic_alias(cls):
          return get_base_class(cls.__origin__)
      return cls
  def _breakpoints_satisfies_typing(cls_check: GenericType, instance: Any) -> bool:
      """Check a breakpoints instance against an expected type hint.

      Args:
          cls_check: The class to check against.
          instance: The instance to check.

      Returns:
          Whether the breakpoints instance satisfies the typing.
      """
      expected_base = get_base_class(cls_check)

      if expected_base == Breakpoints:
          _, expected_type = get_args(cls_check)
          if not is_literal(expected_type):
              return True
          # Every breakpoint value must be a string among the literal's values.
          return all(
              isinstance(value, str) and value in get_args(expected_type)
              for value in instance.values()
          )

      if isinstance(expected_base, tuple):
          # union type, so check all types
          return any(
              _breakpoints_satisfies_typing(member, instance)
              for member in get_args(cls_check)
          )

      if expected_base == reflex.vars.Var and "__args__" in cls_check.__dict__:
          # Parameterized Var: validate against its first type argument.
          return _breakpoints_satisfies_typing(get_args(cls_check)[0], instance)

      return False
  def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None) -> bool:
      """Check whether one type hint is a subclass of another.

      Args:
          cls: The class to check.
          cls_check: The class to check against.
          instance: An instance of cls to aid in checking generics.

      Returns:
          Whether the class is a subclass of the other class.

      Raises:
          TypeError: If the base class is not valid for issubclass.
      """
      # Special check for Any.
      if cls_check == Any:
          return True
      if cls in (Any, Callable, None):
          return False

      # Get the base classes.
      candidate_base = get_base_class(cls)
      target_base = get_base_class(cls_check)

      # The class we're checking should not be a union.
      if isinstance(candidate_base, tuple):
          return False

      # Check that fields of breakpoints match the expected values.
      if isinstance(instance, Breakpoints):
          return _breakpoints_satisfies_typing(cls_check, instance)

      # issubclass cannot take a TypedDict, so compare against plain dict.
      if isinstance(target_base, tuple):
          target_base = tuple(
              dict if is_typeddict(member) else member for member in target_base
          )
      elif is_typeddict(target_base):
          target_base = dict

      # Check if the types match.
      try:
          return target_base == Any or issubclass(candidate_base, target_base)
      except TypeError as te:
          # These errors typically arise from bad annotations and are hard to
          # debug without knowing the type that we tried to compare.
          raise TypeError(f"Invalid type for issubclass: {candidate_base}") from te
  def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
      """Check if an object satisfies a typed dict.

      The object must be a mapping whose keys are all strings declared by the
      typed dict, whose values match the declared types, and which contains
      every required key.

      Args:
          obj: The object to check.
          cls: The typed dict to check against.

      Returns:
          Whether the object satisfies the typed dict.
      """
      if not isinstance(obj, Mapping):
          return False

      key_names_to_values = get_type_hints(cls)
      required_keys: frozenset[str] = getattr(cls, "__required_keys__", frozenset())

      # Every provided key must be declared and hold a value of the right type.
      if not all(
          isinstance(key, str)
          and key in key_names_to_values
          and _isinstance(value, key_names_to_values[key])
          for key, value in obj.items()
      ):
          return False

      # TODO in 3.14: Implement https://peps.python.org/pep-0728/ if it's approved

      # required keys are all present (compare against the object's own keys;
      # comparing the set with itself would always succeed).
      return required_keys.issubset(obj.keys())
  def _isinstance(
      obj: Any,
      cls: GenericType,
      *,
      nested: int = 0,
      treat_var_as_type: bool = True,
      treat_mutable_obj_as_immutable: bool = False,
  ) -> bool:
      """Check if an object is an instance of a class.

      Args:
          obj: The object to check.
          cls: The class to check against.
          nested: How many levels deep to check.
          treat_var_as_type: Whether to treat Var as the type it represents, i.e. _var_type.
          treat_mutable_obj_as_immutable: Whether to treat mutable objects as immutable. Useful if a component declares a mutable object as a prop, but the value is not expected to change.

      Returns:
          Whether the object is an instance of the class.
      """
      if cls is Any:
          return True

      # Aliased so it cannot be confused with the builtin ``set``.
      from collections.abc import Set as AbstractSet

      from reflex.vars import LiteralVar, Var

      if cls is Var:
          return isinstance(obj, Var)
      if isinstance(obj, LiteralVar):
          # A literal var is checked through the Python value it wraps.
          return treat_var_as_type and _isinstance(
              obj._var_value, cls, nested=nested, treat_var_as_type=True
          )
      if isinstance(obj, Var):
          # Other vars are checked through their declared type.
          return treat_var_as_type and typehint_issubclass(
              obj._var_type,
              cls,
              treat_mutable_superclasss_as_immutable=treat_mutable_obj_as_immutable,
              treat_literals_as_union_of_types=True,
              treat_any_as_subtype_of_everything=True,
          )

      if cls is None or cls is type(None):
          return obj is None

      # NOTE(review): treat_mutable_obj_as_immutable is not forwarded to union
      # members or to container items below, matching prior behaviour -
      # confirm whether that is intentional.
      if is_union(cls):
          return any(
              _isinstance(obj, arg, nested=nested, treat_var_as_type=treat_var_as_type)
              for arg in get_args(cls)
          )

      if is_literal(cls):
          return obj in get_args(cls)

      origin = get_origin(cls)

      if origin is None:
          # cls is a typed dict
          if is_typeddict(cls):
              if nested:
                  return does_obj_satisfy_typed_dict(obj, cls)
              return isinstance(obj, dict)

          # cls is a float
          if cls is float:
              return isinstance(obj, (float, int))

          # cls is a simple class
          return isinstance(obj, cls)

      args = get_args(cls)

      if not args:
          if treat_mutable_obj_as_immutable:
              if origin is dict:
                  origin = Mapping
              elif origin is list:
                  origin = Sequence
              elif origin is set:
                  # A real set is not a Sequence, so accept the read-only Set
                  # ABC too; Sequence stays for backward compatibility.
                  origin = (Sequence, AbstractSet)
          # cls is a simple generic class
          return isinstance(obj, origin)

      if origin is Var and args:
          # cls is a Var
          return _isinstance(
              obj,
              args[0],
              nested=nested,
              treat_var_as_type=treat_var_as_type,
              treat_mutable_obj_as_immutable=treat_mutable_obj_as_immutable,
          )

      if nested > 0 and args:
          if origin is list:
              expected_class = Sequence if treat_mutable_obj_as_immutable else list
              return isinstance(obj, expected_class) and all(
                  _isinstance(
                      item,
                      args[0],
                      nested=nested - 1,
                      treat_var_as_type=treat_var_as_type,
                  )
                  for item in obj
              )
          if origin is tuple:
              if args[-1] is Ellipsis:
                  # Homogeneous tuple[X, ...]: every item must match X.
                  return isinstance(obj, tuple) and all(
                      _isinstance(
                          item,
                          args[0],
                          nested=nested - 1,
                          treat_var_as_type=treat_var_as_type,
                      )
                      for item in obj
                  )
              # Fixed-length tuple: lengths must agree and items match pairwise.
              return (
                  isinstance(obj, tuple)
                  and len(obj) == len(args)
                  and all(
                      _isinstance(
                          item,
                          arg,
                          nested=nested - 1,
                          treat_var_as_type=treat_var_as_type,
                      )
                      for item, arg in zip(obj, args, strict=True)
                  )
              )
          if origin in (dict, Mapping, Breakpoints):
              expected_class = (
                  dict
                  if origin is dict and not treat_mutable_obj_as_immutable
                  else Mapping
              )
              return isinstance(obj, expected_class) and all(
                  _isinstance(
                      key, args[0], nested=nested - 1, treat_var_as_type=treat_var_as_type
                  )
                  and _isinstance(
                      value,
                      args[1],
                      nested=nested - 1,
                      treat_var_as_type=treat_var_as_type,
                  )
                  for key, value in obj.items()
              )
          if origin is set:
              # See the un-parameterized branch above: accept real sets as
              # well as sequences when mutability is relaxed.
              expected_class = (
                  (Sequence, AbstractSet) if treat_mutable_obj_as_immutable else set
              )
              return isinstance(obj, expected_class) and all(
                  _isinstance(
                      item,
                      args[0],
                      nested=nested - 1,
                      treat_var_as_type=treat_var_as_type,
                  )
                  for item in obj
              )

      if args:
          from reflex.vars import Field

          if origin is Field:
              # Field[X] is checked against X itself.
              return _isinstance(
                  obj, args[0], nested=nested, treat_var_as_type=treat_var_as_type
              )

      return isinstance(obj, get_base_class(cls))
  def is_dataframe(value: type) -> bool:
      """Tell whether a type is a dataframe class, judged by its name.

      Args:
          value: The value to check.

      Returns:
          False for generic aliases and for ``Any``; otherwise whether the
          type's ``__name__`` is exactly "DataFrame".
      """
      if is_generic_alias(value):
          return False
      if value == Any:
          return False
      return value.__name__ == "DataFrame"
  def is_valid_var_type(type_: type) -> bool:
      """Tell whether a type may be used as a prop type.

      Args:
          type_: The type to check.

      Returns:
          For a union, whether every member is valid. Otherwise whether the
          type is a subclass of ``StateVar``, has a registered serializer, or
          is a dataclass.
      """
      from reflex.utils import serializers

      if is_union(type_):
          return all(map(is_valid_var_type, get_args(type_)))

      # Try each acceptance rule in turn, stopping at the first truthy one.
      verdict = _issubclass(type_, StateVar)
      if not verdict:
          verdict = serializers.has_serializer(type_)
      if not verdict:
          verdict = dataclasses.is_dataclass(type_)
      return verdict
  def is_backend_base_variable(name: str, cls: type) -> bool:
      """Decide whether a variable name denotes a backend variable of a class.

      Args:
          name: The name of the variable to check
          cls: The class of the variable to check

      Returns:
          bool: The result of the check
      """
      if name in RESERVED_BACKEND_VAR_NAMES:
          return False

      # Only single-underscore names that are not name-mangled can qualify.
      if not name.startswith("_"):
          return False
      if name.startswith("__"):
          return False
      if name.startswith(f"_{cls.__name__}__"):
          return False

      # Extract the namespace of the original module if defined (dynamic substates).
      custom_hints = getattr(cls, "_get_type_hints", None)
      hints = cls._get_type_hints() if callable(custom_hints) else get_type_hints(cls)

      # Class-level (ClassVar) annotations are not backend variables.
      if name in hints and get_origin(hints[name]) == ClassVar:
          return False

      if name in cls.inherited_backend_vars:
          return False

      from reflex.vars.base import is_computed_var

      if name in cls.__dict__:
          member = cls.__dict__[name]
          # Methods, properties and computed vars are behaviour, not state.
          if (
              type(member) is classmethod
              or callable(member)
              or isinstance(
                  member,
                  (
                      types.FunctionType,
                      property,
                      cached_property,
                  ),
              )
              or is_computed_var(member)
          ):
              return False

      return True
  def check_type_in_allowed_types(value_type: type, allowed_types: Iterable) -> bool:
      """Tell whether the base class of a type is among the allowed types.

      Args:
          value_type: Type of value.
          allowed_types: Iterable of allowed types.

      Returns:
          If the type is found in the allowed types.
      """
      base_class = get_base_class(value_type)
      return base_class in allowed_types
  def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
      """Tell whether a prop's type is among the allowed types.

      Works for raw values and for state Vars alike: a Var contributes its
      ``_var_type``, any other value contributes ``type(value)``.

      Args:
          prop: The prop to check.
          allowed_types: The list of allowed types.

      Returns:
          If the prop type match one of the allowed_types.
      """
      from reflex.vars import Var

      if isinstance(prop, Var):
          observed_type = prop._var_type
      else:
          observed_type = type(prop)
      return observed_type in allowed_types
  def is_encoded_fstring(value: Any) -> bool:
      """Tell whether a value is an encoded Var f-string.

      Args:
          value: The value string to check.

      Returns:
          True when ``value`` is a string containing
          ``constants.REFLEX_VAR_OPENING_TAG``.
      """
      if not isinstance(value, str):
          return False
      return constants.REFLEX_VAR_OPENING_TAG in value
  def validate_literal(key: str, value: Any, expected_type: type, comp_name: str):
      """Ensure a prop value is one of the values allowed by a Literal type.

      Nothing is checked when the expected type is not a Literal, when the
      value is a Var, or when the value is an encoded f-string.

      Args:
          key: The prop name.
          value: The prop value to validate.
          expected_type: The expected type(literal type).
          comp_name: Name of the component.

      Raises:
          ValueError: When the value is not a valid literal.
      """
      from reflex.vars import Var

      if not is_literal(expected_type):
          return
      if isinstance(value, Var):
          # validating vars is not supported yet.
          return
      if is_encoded_fstring(value):
          # f-strings are not supported.
          return

      allowed_values = expected_type.__args__
      if value in allowed_values:
          return

      # Quote string options so they can be told apart from other values.
      allowed_value_str = ",".join(
          f"'{option}'" if isinstance(option, str) else str(option)
          for option in allowed_values
      )
      value_str = f"'{value}'" if isinstance(value, str) else value
      raise ValueError(
          f"prop value for {key!s} of the `{comp_name}` component should be one of the following: {allowed_value_str}. Got {value_str} instead"
      )
  def validate_parameter_literals(func: Callable):
      """Decorate a function so Literal-annotated parameters are validated.

      Each call checks the passed arguments against the annotation of the
      matching parameter before the wrapped function runs. Applying the
      decorator emits a deprecation notice.

      Args:
          func: The function to validate.

      Returns:
          The wrapper function.
      """
      console.deprecate(
          "validate_parameter_literals",
          reason="Use manual validation instead.",
          deprecation_version="0.7.11",
          removal_version="0.8.0",
          dedupe=True,
      )

      # Parameter name -> annotation, in declaration order.
      annotations = {
          param_name: parameter.annotation
          for param_name, parameter in inspect.signature(func).parameters.items()
      }

      @wraps(func)
      def wrapper(*args, **kwargs):
          # validate args
          for (param_name, annotation), positional in zip(
              annotations.items(), args, strict=False
          ):
              if annotation is not inspect.Parameter.empty:
                  validate_literal(param_name, positional, annotation, func.__name__)

          # validate kwargs.
          for key, value in kwargs.items():
              annotation = annotations.get(key)
              if annotation and annotation is not inspect.Parameter.empty:
                  validate_literal(key, value, annotation, func.__name__)

          return func(*args, **kwargs)

      return wrapper
  # Store this here for performance.
  # Both are computed once at import time: the base classes behind the
  # StateVar and StateIterVar unions, as returned by get_base_class.
  StateBases = get_base_class(StateVar)
  StateIterBases = get_base_class(StateIterVar)
  def safe_issubclass(cls: Any, cls_check: Any | tuple[Any, ...]):
      """Run ``issubclass`` without letting a TypeError escape.

      Args:
          cls: The class to check.
          cls_check: The class to check against.

      Returns:
          The result of ``issubclass(cls, cls_check)``, or False when that call
          raises TypeError (for example when ``cls`` is not a class).
      """
      try:
          outcome = issubclass(cls, cls_check)
      except TypeError:
          return False
      else:
          return outcome
  765. def typehint_issubclass(
  766. possible_subclass: Any,
  767. possible_superclass: Any,
  768. *,
  769. treat_mutable_superclasss_as_immutable: bool = False,
  770. treat_literals_as_union_of_types: bool = True,
  771. treat_any_as_subtype_of_everything: bool = False,
  772. ) -> bool:
  773. """Check if a type hint is a subclass of another type hint.
  774. Args:
  775. possible_subclass: The type hint to check.
  776. possible_superclass: The type hint to check against.
  777. treat_mutable_superclasss_as_immutable: Whether to treat target classes as immutable.
  778. treat_literals_as_union_of_types: Whether to treat literals as a union of their types.
  779. treat_any_as_subtype_of_everything: Whether to treat Any as a subtype of everything. This is the default behavior in Python.
  780. Returns:
  781. Whether the type hint is a subclass of the other type hint.
  782. """
  783. if possible_superclass is Any:
  784. return True
  785. if possible_subclass is Any:
  786. return treat_any_as_subtype_of_everything
  787. if possible_subclass is NoReturn:
  788. return True
  789. provided_type_origin = get_origin(possible_subclass)
  790. accepted_type_origin = get_origin(possible_superclass)
  791. if provided_type_origin is None and accepted_type_origin is None:
  792. # In this case, we are dealing with a non-generic type, so we can use issubclass
  793. return issubclass(possible_subclass, possible_superclass)
  794. if treat_literals_as_union_of_types and is_literal(possible_superclass):
  795. args = get_args(possible_superclass)
  796. return any(
  797. typehint_issubclass(
  798. possible_subclass,
  799. type(arg),
  800. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  801. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  802. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  803. )
  804. for arg in args
  805. )
  806. if is_literal(possible_subclass):
  807. args = get_args(possible_subclass)
  808. return all(
  809. _isinstance(
  810. arg,
  811. possible_superclass,
  812. treat_mutable_obj_as_immutable=treat_mutable_superclasss_as_immutable,
  813. nested=2,
  814. )
  815. for arg in args
  816. )
  817. provided_type_origin = (
  818. Union if provided_type_origin is types.UnionType else provided_type_origin
  819. )
  820. accepted_type_origin = (
  821. Union if accepted_type_origin is types.UnionType else accepted_type_origin
  822. )
  823. # Get type arguments (e.g., [float, int] for dict[float, int])
  824. provided_args = get_args(possible_subclass)
  825. accepted_args = get_args(possible_superclass)
  826. if accepted_type_origin is Union:
  827. if provided_type_origin is not Union:
  828. return any(
  829. typehint_issubclass(
  830. possible_subclass,
  831. accepted_arg,
  832. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  833. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  834. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  835. )
  836. for accepted_arg in accepted_args
  837. )
  838. return all(
  839. any(
  840. typehint_issubclass(
  841. provided_arg,
  842. accepted_arg,
  843. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  844. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  845. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  846. )
  847. for accepted_arg in accepted_args
  848. )
  849. for provided_arg in provided_args
  850. )
  851. if provided_type_origin is Union:
  852. return all(
  853. typehint_issubclass(
  854. provided_arg,
  855. possible_superclass,
  856. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  857. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  858. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  859. )
  860. for provided_arg in provided_args
  861. )
  862. provided_type_origin = provided_type_origin or possible_subclass
  863. accepted_type_origin = accepted_type_origin or possible_superclass
  864. if treat_mutable_superclasss_as_immutable:
  865. if accepted_type_origin is dict:
  866. accepted_type_origin = Mapping
  867. elif accepted_type_origin is list or accepted_type_origin is set:
  868. accepted_type_origin = Sequence
  869. # Check if the origin of both types is the same (e.g., list for list[int])
  870. if not safe_issubclass(
  871. provided_type_origin or possible_subclass,
  872. accepted_type_origin or possible_superclass,
  873. ):
  874. return False
  875. # Ensure all specific types are compatible with accepted types
  876. # Note this is not necessarily correct, as it doesn't check against contravariance and covariance
  877. # It also ignores when the length of the arguments is different
  878. return all(
  879. typehint_issubclass(
  880. provided_arg,
  881. accepted_arg,
  882. treat_mutable_superclasss_as_immutable=treat_mutable_superclasss_as_immutable,
  883. treat_literals_as_union_of_types=treat_literals_as_union_of_types,
  884. treat_any_as_subtype_of_everything=treat_any_as_subtype_of_everything,
  885. )
  886. for provided_arg, accepted_arg in zip(
  887. provided_args, accepted_args, strict=False
  888. )
  889. if accepted_arg is not Any
  890. )