function.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. """Immutable function vars."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import sys
  5. from functools import cached_property
  6. from typing import Any, Callable, Optional, Tuple, Type, Union
  7. from reflex.experimental.vars.base import ImmutableVar, LiteralVar
  8. from reflex.vars import ImmutableVarData, Var, VarData
  9. class FunctionVar(ImmutableVar):
  10. """Base class for immutable function vars."""
  11. def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
  12. """Call the function with the given arguments.
  13. Args:
  14. *args: The arguments to call the function with.
  15. Returns:
  16. The function call operation.
  17. """
  18. return ArgsFunctionOperation(
  19. ("...args",),
  20. VarOperationCall(self, *args, ImmutableVar.create_safe("...args")),
  21. )
  22. def call(self, *args: Var | Any) -> VarOperationCall:
  23. """Call the function with the given arguments.
  24. Args:
  25. *args: The arguments to call the function with.
  26. Returns:
  27. The function call operation.
  28. """
  29. return VarOperationCall(self, *args)
  30. class FunctionStringVar(FunctionVar):
  31. """Base class for immutable function vars from a string."""
  32. def __init__(self, func: str, _var_data: VarData | None = None) -> None:
  33. """Initialize the function var.
  34. Args:
  35. func: The function to call.
  36. _var_data: Additional hooks and imports associated with the Var.
  37. """
  38. super(FunctionVar, self).__init__(
  39. _var_name=func,
  40. _var_type=Callable,
  41. _var_data=ImmutableVarData.merge(_var_data),
  42. )
  43. @dataclasses.dataclass(
  44. eq=False,
  45. frozen=True,
  46. **{"slots": True} if sys.version_info >= (3, 10) else {},
  47. )
  48. class VarOperationCall(ImmutableVar):
  49. """Base class for immutable vars that are the result of a function call."""
  50. _func: Optional[FunctionVar] = dataclasses.field(default=None)
  51. _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
  52. def __init__(
  53. self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None
  54. ):
  55. """Initialize the function call var.
  56. Args:
  57. func: The function to call.
  58. *args: The arguments to call the function with.
  59. _var_data: Additional hooks and imports associated with the Var.
  60. """
  61. super(VarOperationCall, self).__init__(
  62. _var_name="",
  63. _var_type=Any,
  64. _var_data=ImmutableVarData.merge(_var_data),
  65. )
  66. object.__setattr__(self, "_func", func)
  67. object.__setattr__(self, "_args", args)
  68. object.__delattr__(self, "_var_name")
  69. def __getattr__(self, name):
  70. """Get an attribute of the var.
  71. Args:
  72. name: The name of the attribute.
  73. Returns:
  74. The attribute of the var.
  75. """
  76. if name == "_var_name":
  77. return self._cached_var_name
  78. return super(type(self), self).__getattr__(name)
  79. @cached_property
  80. def _cached_var_name(self) -> str:
  81. """The name of the var.
  82. Returns:
  83. The name of the var.
  84. """
  85. return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
  86. @cached_property
  87. def _cached_get_all_var_data(self) -> ImmutableVarData | None:
  88. """Get all VarData associated with the Var.
  89. Returns:
  90. The VarData of the components and all of its children.
  91. """
  92. return ImmutableVarData.merge(
  93. self._func._get_all_var_data() if self._func is not None else None,
  94. *[var._get_all_var_data() for var in self._args],
  95. self._var_data,
  96. )
  97. def _get_all_var_data(self) -> ImmutableVarData | None:
  98. """Wrapper method for cached property.
  99. Returns:
  100. The VarData of the components and all of its children.
  101. """
  102. return self._cached_get_all_var_data
  103. def __post_init__(self):
  104. """Post-initialize the var."""
  105. pass
  106. @dataclasses.dataclass(
  107. eq=False,
  108. frozen=True,
  109. **{"slots": True} if sys.version_info >= (3, 10) else {},
  110. )
  111. class ArgsFunctionOperation(FunctionVar):
  112. """Base class for immutable function defined via arguments and return expression."""
  113. _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
  114. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  115. def __init__(
  116. self,
  117. args_names: Tuple[str, ...],
  118. return_expr: Var | Any,
  119. _var_data: VarData | None = None,
  120. ) -> None:
  121. """Initialize the function with arguments var.
  122. Args:
  123. args_names: The names of the arguments.
  124. return_expr: The return expression of the function.
  125. _var_data: Additional hooks and imports associated with the Var.
  126. """
  127. super(ArgsFunctionOperation, self).__init__(
  128. _var_name=f"",
  129. _var_type=Callable,
  130. _var_data=ImmutableVarData.merge(_var_data),
  131. )
  132. object.__setattr__(self, "_args_names", args_names)
  133. object.__setattr__(self, "_return_expr", return_expr)
  134. object.__delattr__(self, "_var_name")
  135. def __getattr__(self, name):
  136. """Get an attribute of the var.
  137. Args:
  138. name: The name of the attribute.
  139. Returns:
  140. The attribute of the var.
  141. """
  142. if name == "_var_name":
  143. return self._cached_var_name
  144. return super(type(self), self).__getattr__(name)
  145. @cached_property
  146. def _cached_var_name(self) -> str:
  147. """The name of the var.
  148. Returns:
  149. The name of the var.
  150. """
  151. return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
  152. @cached_property
  153. def _cached_get_all_var_data(self) -> ImmutableVarData | None:
  154. """Get all VarData associated with the Var.
  155. Returns:
  156. The VarData of the components and all of its children.
  157. """
  158. return ImmutableVarData.merge(
  159. self._return_expr._get_all_var_data(),
  160. self._var_data,
  161. )
  162. def _get_all_var_data(self) -> ImmutableVarData | None:
  163. """Wrapper method for cached property.
  164. Returns:
  165. The VarData of the components and all of its children.
  166. """
  167. return self._cached_get_all_var_data
  168. def __post_init__(self):
  169. """Post-initialize the var."""
  170. @dataclasses.dataclass(
  171. eq=False,
  172. frozen=True,
  173. **{"slots": True} if sys.version_info >= (3, 10) else {},
  174. )
  175. class ToFunctionOperation(FunctionVar):
  176. """Base class of converting a var to a function."""
  177. _original_var: Var = dataclasses.field(
  178. default_factory=lambda: LiteralVar.create(None)
  179. )
  180. def __init__(
  181. self,
  182. original_var: Var,
  183. _var_type: Type[Callable] = Callable,
  184. _var_data: VarData | None = None,
  185. ) -> None:
  186. """Initialize the function with arguments var.
  187. Args:
  188. original_var: The original var to convert to a function.
  189. _var_type: The type of the function.
  190. _var_data: Additional hooks and imports associated with the Var.
  191. """
  192. super(ToFunctionOperation, self).__init__(
  193. _var_name=f"",
  194. _var_type=_var_type,
  195. _var_data=ImmutableVarData.merge(_var_data),
  196. )
  197. object.__setattr__(self, "_original_var", original_var)
  198. object.__delattr__(self, "_var_name")
  199. def __getattr__(self, name):
  200. """Get an attribute of the var.
  201. Args:
  202. name: The name of the attribute.
  203. Returns:
  204. The attribute of the var.
  205. """
  206. if name == "_var_name":
  207. return self._cached_var_name
  208. return super(type(self), self).__getattr__(name)
  209. @cached_property
  210. def _cached_var_name(self) -> str:
  211. """The name of the var.
  212. Returns:
  213. The name of the var.
  214. """
  215. return str(self._original_var)
  216. @cached_property
  217. def _cached_get_all_var_data(self) -> ImmutableVarData | None:
  218. """Get all VarData associated with the Var.
  219. Returns:
  220. The VarData of the components and all of its children.
  221. """
  222. return ImmutableVarData.merge(
  223. self._original_var._get_all_var_data(),
  224. self._var_data,
  225. )
  226. def _get_all_var_data(self) -> ImmutableVarData | None:
  227. """Wrapper method for cached property.
  228. Returns:
  229. The VarData of the components and all of its children.
  230. """
  231. return self._cached_get_all_var_data