base.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. """Collection of base classes."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import functools
  5. import inspect
  6. import sys
  7. from typing import (
  8. TYPE_CHECKING,
  9. Any,
  10. Callable,
  11. Optional,
  12. Type,
  13. TypeVar,
  14. overload,
  15. )
  16. from typing_extensions import ParamSpec, get_origin
  17. from reflex import constants
  18. from reflex.base import Base
  19. from reflex.utils import serializers, types
  20. from reflex.utils.exceptions import VarTypeError
  21. from reflex.vars import (
  22. ImmutableVarData,
  23. Var,
  24. VarData,
  25. _decode_var_immutable,
  26. _extract_var_data,
  27. _global_vars,
  28. )
  29. if TYPE_CHECKING:
  30. from .function import FunctionVar, ToFunctionOperation
  31. from .number import (
  32. BooleanVar,
  33. NumberVar,
  34. ToBooleanVarOperation,
  35. ToNumberVarOperation,
  36. )
  37. from .object import ObjectVar, ToObjectOperation
  38. from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
  39. @dataclasses.dataclass(
  40. eq=False,
  41. frozen=True,
  42. **{"slots": True} if sys.version_info >= (3, 10) else {},
  43. )
  44. class ImmutableVar(Var):
  45. """Base class for immutable vars."""
  46. # The name of the var.
  47. _var_name: str = dataclasses.field()
  48. # The type of the var.
  49. _var_type: types.GenericType = dataclasses.field(default=Any)
  50. # Extra metadata associated with the Var
  51. _var_data: Optional[ImmutableVarData] = dataclasses.field(default=None)
  52. def __str__(self) -> str:
  53. """String representation of the var. Guaranteed to be a valid Javascript expression.
  54. Returns:
  55. The name of the var.
  56. """
  57. return self._var_name
  58. @property
  59. def _var_is_local(self) -> bool:
  60. """Whether this is a local javascript variable.
  61. Returns:
  62. False
  63. """
  64. return False
  65. @property
  66. def _var_is_string(self) -> bool:
  67. """Whether the var is a string literal.
  68. Returns:
  69. False
  70. """
  71. return False
  72. @property
  73. def _var_full_name_needs_state_prefix(self) -> bool:
  74. """Whether the full name of the var needs a _var_state prefix.
  75. Returns:
  76. False
  77. """
  78. return False
  79. def __post_init__(self):
  80. """Post-initialize the var."""
  81. # Decode any inline Var markup and apply it to the instance
  82. _var_data, _var_name = _decode_var_immutable(self._var_name)
  83. if _var_data or _var_name != self._var_name:
  84. self.__init__(
  85. _var_name=_var_name,
  86. _var_type=self._var_type,
  87. _var_data=ImmutableVarData.merge(self._var_data, _var_data),
  88. )
  89. def __hash__(self) -> int:
  90. """Define a hash function for the var.
  91. Returns:
  92. The hash of the var.
  93. """
  94. return hash((self._var_name, self._var_type, self._var_data))
  95. def _get_all_var_data(self) -> ImmutableVarData | None:
  96. """Get all VarData associated with the Var.
  97. Returns:
  98. The VarData of the components and all of its children.
  99. """
  100. return self._var_data
  101. def _replace(self, merge_var_data=None, **kwargs: Any):
  102. """Make a copy of this Var with updated fields.
  103. Args:
  104. merge_var_data: VarData to merge into the existing VarData.
  105. **kwargs: Var fields to update.
  106. Returns:
  107. A new ImmutableVar with the updated fields overwriting the corresponding fields in this Var.
  108. Raises:
  109. TypeError: If _var_is_local, _var_is_string, or _var_full_name_needs_state_prefix is not None.
  110. """
  111. if kwargs.get("_var_is_local", False) is not False:
  112. raise TypeError(
  113. "The _var_is_local argument is not supported for ImmutableVar."
  114. )
  115. if kwargs.get("_var_is_string", False) is not False:
  116. raise TypeError(
  117. "The _var_is_string argument is not supported for ImmutableVar."
  118. )
  119. if kwargs.get("_var_full_name_needs_state_prefix", False) is not False:
  120. raise TypeError(
  121. "The _var_full_name_needs_state_prefix argument is not supported for ImmutableVar."
  122. )
  123. field_values = dict(
  124. _var_name=kwargs.pop("_var_name", self._var_name),
  125. _var_type=kwargs.pop("_var_type", self._var_type),
  126. _var_data=ImmutableVarData.merge(
  127. kwargs.get("_var_data", self._var_data), merge_var_data
  128. ),
  129. )
  130. return type(self)(**field_values)
  131. @classmethod
  132. def create(
  133. cls,
  134. value: Any,
  135. _var_is_local: bool | None = None,
  136. _var_is_string: bool | None = None,
  137. _var_data: VarData | None = None,
  138. ) -> ImmutableVar | Var | None:
  139. """Create a var from a value.
  140. Args:
  141. value: The value to create the var from.
  142. _var_is_local: Whether the var is local. Deprecated.
  143. _var_is_string: Whether the var is a string literal. Deprecated.
  144. _var_data: Additional hooks and imports associated with the Var.
  145. Returns:
  146. The var.
  147. Raises:
  148. VarTypeError: If the value is JSON-unserializable.
  149. TypeError: If _var_is_local or _var_is_string is not None.
  150. """
  151. if _var_is_local is not None:
  152. raise TypeError(
  153. "The _var_is_local argument is not supported for ImmutableVar."
  154. )
  155. if _var_is_string is not None:
  156. raise TypeError(
  157. "The _var_is_string argument is not supported for ImmutableVar."
  158. )
  159. from reflex.utils import format
  160. # Check for none values.
  161. if value is None:
  162. return None
  163. # If the value is already a var, do nothing.
  164. if isinstance(value, Var):
  165. return value
  166. # Try to pull the imports and hooks from contained values.
  167. if not isinstance(value, str):
  168. _var_data = VarData.merge(*_extract_var_data(value), _var_data)
  169. # Try to serialize the value.
  170. type_ = type(value)
  171. if type_ in types.JSONType:
  172. name = value
  173. else:
  174. name, _serialized_type = serializers.serialize(value, get_type=True)
  175. if name is None:
  176. raise VarTypeError(
  177. f"No JSON serializer found for var {value} of type {type_}."
  178. )
  179. name = name if isinstance(name, str) else format.json_dumps(name)
  180. return cls(
  181. _var_name=name,
  182. _var_type=type_,
  183. _var_data=(
  184. ImmutableVarData(
  185. state=_var_data.state,
  186. imports=_var_data.imports,
  187. hooks=_var_data.hooks,
  188. )
  189. if _var_data
  190. else None
  191. ),
  192. )
  193. @classmethod
  194. def create_safe(
  195. cls,
  196. value: Any,
  197. _var_is_local: bool | None = None,
  198. _var_is_string: bool | None = None,
  199. _var_data: VarData | None = None,
  200. ) -> Var | ImmutableVar:
  201. """Create a var from a value, asserting that it is not None.
  202. Args:
  203. value: The value to create the var from.
  204. _var_is_local: Whether the var is local. Deprecated.
  205. _var_is_string: Whether the var is a string literal. Deprecated.
  206. _var_data: Additional hooks and imports associated with the Var.
  207. Returns:
  208. The var.
  209. """
  210. var = cls.create(
  211. value,
  212. _var_is_local=_var_is_local,
  213. _var_is_string=_var_is_string,
  214. _var_data=_var_data,
  215. )
  216. assert var is not None
  217. return var
  218. def __format__(self, format_spec: str) -> str:
  219. """Format the var into a Javascript equivalent to an f-string.
  220. Args:
  221. format_spec: The format specifier (Ignored for now).
  222. Returns:
  223. The formatted var.
  224. """
  225. hashed_var = hash(self)
  226. _global_vars[hashed_var] = self
  227. # Encode the _var_data into the formatted output for tracking purposes.
  228. return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}"
  229. @overload
  230. def to(
  231. self, output: Type[NumberVar], var_type: type[int] | type[float] = float
  232. ) -> ToNumberVarOperation: ...
  233. @overload
  234. def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ...
  235. @overload
  236. def to(
  237. self,
  238. output: Type[ArrayVar],
  239. var_type: type[list] | type[tuple] | type[set] = list,
  240. ) -> ToArrayOperation: ...
  241. @overload
  242. def to(self, output: Type[StringVar]) -> ToStringOperation: ...
  243. @overload
  244. def to(
  245. self, output: Type[ObjectVar], var_type: types.GenericType = dict
  246. ) -> ToObjectOperation: ...
  247. @overload
  248. def to(
  249. self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
  250. ) -> ToFunctionOperation: ...
  251. @overload
  252. def to(
  253. self, output: Type[OUTPUT], var_type: types.GenericType | None = None
  254. ) -> OUTPUT: ...
  255. def to(
  256. self, output: Type[OUTPUT], var_type: types.GenericType | None = None
  257. ) -> Var:
  258. """Convert the var to a different type.
  259. Args:
  260. output: The output type.
  261. var_type: The type of the var.
  262. Raises:
  263. TypeError: If the var_type is not a supported type for the output.
  264. Returns:
  265. The converted var.
  266. """
  267. from .number import (
  268. BooleanVar,
  269. NumberVar,
  270. ToBooleanVarOperation,
  271. ToNumberVarOperation,
  272. )
  273. fixed_type = (
  274. var_type
  275. if var_type is None or inspect.isclass(var_type)
  276. else get_origin(var_type)
  277. )
  278. if issubclass(output, NumberVar):
  279. if fixed_type is not None and not issubclass(fixed_type, (int, float)):
  280. raise TypeError(
  281. f"Unsupported type {var_type} for NumberVar. Must be int or float."
  282. )
  283. return ToNumberVarOperation(self, var_type or float)
  284. if issubclass(output, BooleanVar):
  285. return ToBooleanVarOperation(self)
  286. from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
  287. if issubclass(output, ArrayVar):
  288. if fixed_type is not None and not issubclass(
  289. fixed_type, (list, tuple, set)
  290. ):
  291. raise TypeError(
  292. f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set."
  293. )
  294. return ToArrayOperation(self, var_type or list)
  295. if issubclass(output, StringVar):
  296. return ToStringOperation(self)
  297. from .object import ObjectVar, ToObjectOperation
  298. if issubclass(output, ObjectVar):
  299. return ToObjectOperation(self, var_type or dict)
  300. from .function import FunctionVar, ToFunctionOperation
  301. if issubclass(output, FunctionVar):
  302. if fixed_type is not None and not issubclass(fixed_type, Callable):
  303. raise TypeError(
  304. f"Unsupported type {var_type} for FunctionVar. Must be Callable."
  305. )
  306. return ToFunctionOperation(self, var_type or Callable)
  307. return output(
  308. _var_name=self._var_name,
  309. _var_type=self._var_type if var_type is None else var_type,
  310. _var_data=self._var_data,
  311. )
  312. def guess_type(self) -> ImmutableVar:
  313. """Guess the type of the var.
  314. Returns:
  315. The guessed type.
  316. """
  317. from .number import NumberVar
  318. from .object import ObjectVar
  319. from .sequence import ArrayVar, StringVar
  320. if self._var_type is Any:
  321. return self
  322. var_type = self._var_type
  323. fixed_type = var_type if inspect.isclass(var_type) else get_origin(var_type)
  324. if issubclass(fixed_type, (int, float)):
  325. return self.to(NumberVar, var_type)
  326. if issubclass(fixed_type, dict):
  327. return self.to(ObjectVar, var_type)
  328. if issubclass(fixed_type, (list, tuple, set)):
  329. return self.to(ArrayVar, var_type)
  330. if issubclass(fixed_type, str):
  331. return self.to(StringVar)
  332. return self
# Type variable for the generic fallback overload of ImmutableVar.to(): the
# result is an instance of whichever ImmutableVar subclass is passed as
# `output`. It is declared after the class because it is bound to it; the
# annotations that mention it are evaluated lazily (`from __future__ import
# annotations`), so the forward use inside the class is fine.
OUTPUT = TypeVar("OUTPUT", bound=ImmutableVar)
  334. class LiteralVar(ImmutableVar):
  335. """Base class for immutable literal vars."""
  336. @classmethod
  337. def create(
  338. cls,
  339. value: Any,
  340. _var_data: VarData | None = None,
  341. ) -> Var:
  342. """Create a var from a value.
  343. Args:
  344. value: The value to create the var from.
  345. _var_data: Additional hooks and imports associated with the Var.
  346. Returns:
  347. The var.
  348. Raises:
  349. TypeError: If the value is not a supported type for LiteralVar.
  350. """
  351. if isinstance(value, Var):
  352. if _var_data is None:
  353. return value
  354. return value._replace(merge_var_data=_var_data)
  355. if value is None:
  356. return ImmutableVar.create_safe("null", _var_data=_var_data)
  357. from .object import LiteralObjectVar
  358. if isinstance(value, Base):
  359. return LiteralObjectVar(
  360. value.dict(), _var_type=type(value), _var_data=_var_data
  361. )
  362. from .number import LiteralBooleanVar, LiteralNumberVar
  363. from .sequence import LiteralArrayVar, LiteralStringVar
  364. if isinstance(value, str):
  365. return LiteralStringVar.create(value, _var_data=_var_data)
  366. type_mapping = {
  367. int: LiteralNumberVar,
  368. float: LiteralNumberVar,
  369. bool: LiteralBooleanVar,
  370. dict: LiteralObjectVar,
  371. list: LiteralArrayVar,
  372. tuple: LiteralArrayVar,
  373. set: LiteralArrayVar,
  374. }
  375. constructor = type_mapping.get(type(value))
  376. if constructor is None:
  377. raise TypeError(f"Unsupported type {type(value)} for LiteralVar.")
  378. return constructor(value, _var_data=_var_data)
  379. def __post_init__(self):
  380. """Post-initialize the var."""
  381. def json(self) -> str:
  382. """Serialize the var to a JSON string.
  383. Raises:
  384. NotImplementedError: If the method is not implemented.
  385. """
  386. raise NotImplementedError(
  387. "LiteralVar subclasses must implement the json method."
  388. )
# Typing helpers for var_operation(): P captures the decorated function's
# parameters so the wrapper keeps the same signature, and T is the
# ImmutableVar subclass the wrapper returns (the decorator's `output`).
P = ParamSpec("P")
T = TypeVar("T", bound=ImmutableVar)
  391. def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P, T]]:
  392. """Decorator for creating a var operation.
  393. Example:
  394. ```python
  395. @var_operation(output=NumberVar)
  396. def add(a: NumberVar, b: NumberVar):
  397. return f"({a} + {b})"
  398. ```
  399. Args:
  400. output: The output type of the operation.
  401. Returns:
  402. The decorator.
  403. """
  404. def decorator(func: Callable[P, str], output=output):
  405. @functools.wraps(func)
  406. def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
  407. args_vars = [
  408. LiteralVar.create(arg) if not isinstance(arg, Var) else arg
  409. for arg in args
  410. ]
  411. kwargs_vars = {
  412. key: LiteralVar.create(value) if not isinstance(value, Var) else value
  413. for key, value in kwargs.items()
  414. }
  415. return output(
  416. _var_name=func(*args_vars, **kwargs_vars), # type: ignore
  417. _var_data=VarData.merge(
  418. *[arg._get_all_var_data() for arg in args if isinstance(arg, Var)],
  419. *[
  420. arg._get_all_var_data()
  421. for arg in kwargs.values()
  422. if isinstance(arg, Var)
  423. ],
  424. ),
  425. )
  426. return wrapper
  427. return decorator