base.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. """Collection of base classes."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import json
  5. import re
  6. import sys
  7. from functools import cached_property
  8. from typing import Any, Optional, Type
  9. from reflex import constants
  10. from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
  11. from reflex.utils import serializers, types
  12. from reflex.utils.exceptions import VarTypeError
  13. from reflex.vars import (
  14. ImmutableVarData,
  15. Var,
  16. VarData,
  17. _decode_var_immutable,
  18. _extract_var_data,
  19. _global_vars,
  20. )
  21. @dataclasses.dataclass(
  22. eq=False,
  23. frozen=True,
  24. **{"slots": True} if sys.version_info >= (3, 10) else {},
  25. )
  26. class ImmutableVar(Var):
  27. """Base class for immutable vars."""
  28. # The name of the var.
  29. _var_name: str = dataclasses.field()
  30. # The type of the var.
  31. _var_type: Type = dataclasses.field(default=Any)
  32. # Extra metadata associated with the Var
  33. _var_data: Optional[ImmutableVarData] = dataclasses.field(default=None)
  34. def __str__(self) -> str:
  35. """String representation of the var. Guaranteed to be a valid Javascript expression.
  36. Returns:
  37. The name of the var.
  38. """
  39. return self._var_name
  40. @property
  41. def _var_is_local(self) -> bool:
  42. """Whether this is a local javascript variable.
  43. Returns:
  44. False
  45. """
  46. return False
  47. @property
  48. def _var_is_string(self) -> bool:
  49. """Whether the var is a string literal.
  50. Returns:
  51. False
  52. """
  53. return False
  54. @property
  55. def _var_full_name_needs_state_prefix(self) -> bool:
  56. """Whether the full name of the var needs a _var_state prefix.
  57. Returns:
  58. False
  59. """
  60. return False
  61. def __post_init__(self):
  62. """Post-initialize the var."""
  63. # Decode any inline Var markup and apply it to the instance
  64. _var_data, _var_name = _decode_var_immutable(self._var_name)
  65. if _var_data:
  66. self.__init__(
  67. _var_name,
  68. self._var_type,
  69. ImmutableVarData.merge(self._var_data, _var_data),
  70. )
  71. def __hash__(self) -> int:
  72. """Define a hash function for the var.
  73. Returns:
  74. The hash of the var.
  75. """
  76. return hash((self._var_name, self._var_type, self._var_data))
  77. def _get_all_var_data(self) -> ImmutableVarData | None:
  78. return self._var_data
  79. def _replace(self, merge_var_data=None, **kwargs: Any):
  80. """Make a copy of this Var with updated fields.
  81. Args:
  82. merge_var_data: VarData to merge into the existing VarData.
  83. **kwargs: Var fields to update.
  84. Returns:
  85. A new ImmutableVar with the updated fields overwriting the corresponding fields in this Var.
  86. Raises:
  87. TypeError: If _var_is_local, _var_is_string, or _var_full_name_needs_state_prefix is not None.
  88. """
  89. if kwargs.get("_var_is_local", False) is not False:
  90. raise TypeError(
  91. "The _var_is_local argument is not supported for ImmutableVar."
  92. )
  93. if kwargs.get("_var_is_string", False) is not False:
  94. raise TypeError(
  95. "The _var_is_string argument is not supported for ImmutableVar."
  96. )
  97. if kwargs.get("_var_full_name_needs_state_prefix", False) is not False:
  98. raise TypeError(
  99. "The _var_full_name_needs_state_prefix argument is not supported for ImmutableVar."
  100. )
  101. field_values = dict(
  102. _var_name=kwargs.pop("_var_name", self._var_name),
  103. _var_type=kwargs.pop("_var_type", self._var_type),
  104. _var_data=ImmutableVarData.merge(
  105. kwargs.get("_var_data", self._var_data), merge_var_data
  106. ),
  107. )
  108. return type(self)(**field_values)
  109. @classmethod
  110. def create(
  111. cls,
  112. value: Any,
  113. _var_is_local: bool | None = None,
  114. _var_is_string: bool | None = None,
  115. _var_data: VarData | None = None,
  116. ) -> ImmutableVar | Var | None:
  117. """Create a var from a value.
  118. Args:
  119. value: The value to create the var from.
  120. _var_is_local: Whether the var is local. Deprecated.
  121. _var_is_string: Whether the var is a string literal. Deprecated.
  122. _var_data: Additional hooks and imports associated with the Var.
  123. Returns:
  124. The var.
  125. Raises:
  126. VarTypeError: If the value is JSON-unserializable.
  127. TypeError: If _var_is_local or _var_is_string is not None.
  128. """
  129. if _var_is_local is not None:
  130. raise TypeError(
  131. "The _var_is_local argument is not supported for ImmutableVar."
  132. )
  133. if _var_is_string is not None:
  134. raise TypeError(
  135. "The _var_is_string argument is not supported for ImmutableVar."
  136. )
  137. from reflex.utils import format
  138. # Check for none values.
  139. if value is None:
  140. return None
  141. # If the value is already a var, do nothing.
  142. if isinstance(value, Var):
  143. return value
  144. # Try to pull the imports and hooks from contained values.
  145. if not isinstance(value, str):
  146. _var_data = VarData.merge(*_extract_var_data(value), _var_data)
  147. # Try to serialize the value.
  148. type_ = type(value)
  149. if type_ in types.JSONType:
  150. name = value
  151. else:
  152. name, _serialized_type = serializers.serialize(value, get_type=True)
  153. if name is None:
  154. raise VarTypeError(
  155. f"No JSON serializer found for var {value} of type {type_}."
  156. )
  157. name = name if isinstance(name, str) else format.json_dumps(name)
  158. return cls(
  159. _var_name=name,
  160. _var_type=type_,
  161. _var_data=(
  162. ImmutableVarData(
  163. state=_var_data.state,
  164. imports=_var_data.imports,
  165. hooks=_var_data.hooks,
  166. )
  167. if _var_data
  168. else None
  169. ),
  170. )
  171. @classmethod
  172. def create_safe(
  173. cls,
  174. value: Any,
  175. _var_is_local: bool | None = None,
  176. _var_is_string: bool | None = None,
  177. _var_data: VarData | None = None,
  178. ) -> Var | ImmutableVar:
  179. """Create a var from a value, asserting that it is not None.
  180. Args:
  181. value: The value to create the var from.
  182. _var_is_local: Whether the var is local. Deprecated.
  183. _var_is_string: Whether the var is a string literal. Deprecated.
  184. _var_data: Additional hooks and imports associated with the Var.
  185. Returns:
  186. The var.
  187. """
  188. var = cls.create(
  189. value,
  190. _var_is_local=_var_is_local,
  191. _var_is_string=_var_is_string,
  192. _var_data=_var_data,
  193. )
  194. assert var is not None
  195. return var
  196. def __format__(self, format_spec: str) -> str:
  197. """Format the var into a Javascript equivalent to an f-string.
  198. Args:
  199. format_spec: The format specifier (Ignored for now).
  200. Returns:
  201. The formatted var.
  202. """
  203. hashed_var = hash(self)
  204. _global_vars[hashed_var] = self
  205. # Encode the _var_data into the formatted output for tracking purposes.
  206. return f"{REFLEX_VAR_OPENING_TAG}{hashed_var}{REFLEX_VAR_CLOSING_TAG}{self._var_name}"
  207. class StringVar(ImmutableVar):
  208. """Base class for immutable string vars."""
  209. class NumberVar(ImmutableVar):
  210. """Base class for immutable number vars."""
  211. class BooleanVar(ImmutableVar):
  212. """Base class for immutable boolean vars."""
  213. class ObjectVar(ImmutableVar):
  214. """Base class for immutable object vars."""
  215. class ArrayVar(ImmutableVar):
  216. """Base class for immutable array vars."""
  217. class FunctionVar(ImmutableVar):
  218. """Base class for immutable function vars."""
  219. class LiteralVar(ImmutableVar):
  220. """Base class for immutable literal vars."""
  221. def __post_init__(self):
  222. """Post-initialize the var."""
  223. # Compile regex for finding reflex var tags.
  224. _decode_var_pattern_re = (
  225. rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
  226. )
  227. _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
  228. @dataclasses.dataclass(
  229. eq=False,
  230. frozen=True,
  231. **{"slots": True} if sys.version_info >= (3, 10) else {},
  232. )
  233. class LiteralStringVar(LiteralVar):
  234. """Base class for immutable literal string vars."""
  235. _var_value: Optional[str] = dataclasses.field(default=None)
  236. @classmethod
  237. def create(
  238. cls,
  239. value: str,
  240. _var_data: VarData | None = None,
  241. ) -> LiteralStringVar | ConcatVarOperation:
  242. """Create a var from a string value.
  243. Args:
  244. value: The value to create the var from.
  245. _var_data: Additional hooks and imports associated with the Var.
  246. Returns:
  247. The var.
  248. """
  249. if REFLEX_VAR_OPENING_TAG in value:
  250. strings_and_vals: list[Var] = []
  251. offset = 0
  252. # Initialize some methods for reading json.
  253. var_data_config = VarData().__config__
  254. def json_loads(s):
  255. try:
  256. return var_data_config.json_loads(s)
  257. except json.decoder.JSONDecodeError:
  258. return var_data_config.json_loads(
  259. var_data_config.json_loads(f'"{s}"')
  260. )
  261. # Find all tags.
  262. while m := _decode_var_pattern.search(value):
  263. start, end = m.span()
  264. if start > 0:
  265. strings_and_vals.append(LiteralStringVar.create(value[:start]))
  266. serialized_data = m.group(1)
  267. if serialized_data[1:].isnumeric():
  268. # This is a global immutable var.
  269. var = _global_vars[int(serialized_data)]
  270. strings_and_vals.append(var)
  271. value = value[(end + len(var._var_name)) :]
  272. else:
  273. data = json_loads(serialized_data)
  274. string_length = data.pop("string_length", None)
  275. var_data = VarData.parse_obj(data)
  276. # Use string length to compute positions of interpolations.
  277. if string_length is not None:
  278. realstart = start + offset
  279. var_data.interpolations = [
  280. (realstart, realstart + string_length)
  281. ]
  282. strings_and_vals.append(
  283. ImmutableVar.create_safe(
  284. value[end : (end + string_length)], _var_data=var_data
  285. )
  286. )
  287. value = value[(end + string_length) :]
  288. offset += end - start
  289. if value:
  290. strings_and_vals.append(LiteralStringVar.create(value))
  291. return ConcatVarOperation.create(
  292. tuple(strings_and_vals), _var_data=_var_data
  293. )
  294. return cls(
  295. _var_value=value,
  296. _var_name=f'"{value}"',
  297. _var_type=str,
  298. _var_data=ImmutableVarData.merge(_var_data),
  299. )
  300. @dataclasses.dataclass(
  301. eq=False,
  302. frozen=True,
  303. **{"slots": True} if sys.version_info >= (3, 10) else {},
  304. )
  305. class ConcatVarOperation(StringVar):
  306. """Representing a concatenation of literal string vars."""
  307. _var_value: tuple[Var, ...] = dataclasses.field(default_factory=tuple)
  308. def __init__(self, _var_value: tuple[Var, ...], _var_data: VarData | None = None):
  309. """Initialize the operation of concatenating literal string vars.
  310. Args:
  311. _var_value: The list of vars to concatenate.
  312. _var_data: Additional hooks and imports associated with the Var.
  313. """
  314. super(ConcatVarOperation, self).__init__(
  315. _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str
  316. )
  317. object.__setattr__(self, "_var_value", _var_value)
  318. object.__setattr__(self, "_var_name", self._cached_var_name)
  319. @cached_property
  320. def _cached_var_name(self) -> str:
  321. """The name of the var.
  322. Returns:
  323. The name of the var.
  324. """
  325. return "+".join([str(element) for element in self._var_value])
  326. @cached_property
  327. def _cached_get_all_var_data(self) -> ImmutableVarData | None:
  328. """Get all VarData associated with the Var.
  329. Returns:
  330. The VarData of the components and all of its children.
  331. """
  332. return ImmutableVarData.merge(
  333. *[var._get_all_var_data() for var in self._var_value], self._var_data
  334. )
  335. def _get_all_var_data(self) -> ImmutableVarData | None:
  336. """Wrapper method for cached property.
  337. Returns:
  338. The VarData of the components and all of its children.
  339. """
  340. return self._cached_get_all_var_data
  341. def __post_init__(self):
  342. """Post-initialize the var."""
  343. pass
  344. @classmethod
  345. def create(
  346. cls,
  347. value: tuple[Var, ...],
  348. _var_data: VarData | None = None,
  349. ) -> ConcatVarOperation:
  350. """Create a var from a tuple of values.
  351. Args:
  352. value: The value to create the var from.
  353. _var_data: Additional hooks and imports associated with the Var.
  354. Returns:
  355. The var.
  356. """
  357. return ConcatVarOperation(
  358. _var_value=value,
  359. _var_data=_var_data,
  360. )