function.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. """Immutable function vars."""
  2. from __future__ import annotations
  3. import dataclasses
  4. from collections.abc import Callable, Sequence
  5. from typing import Any, Concatenate, Generic, ParamSpec, Protocol, TypeVar, overload
  6. from reflex.utils import format
  7. from reflex.utils.types import GenericType
  8. from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
  9. P = ParamSpec("P")
  10. V1 = TypeVar("V1")
  11. V2 = TypeVar("V2")
  12. V3 = TypeVar("V3")
  13. V4 = TypeVar("V4")
  14. V5 = TypeVar("V5")
  15. V6 = TypeVar("V6")
  16. R = TypeVar("R")
  17. class ReflexCallable(Protocol[P, R]):
  18. """Protocol for a callable."""
  19. __call__: Callable[P, R]
  20. CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, covariant=True)
  21. OTHER_CALLABLE_TYPE = TypeVar(
  22. "OTHER_CALLABLE_TYPE", bound=ReflexCallable, covariant=True
  23. )
  24. class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
  25. """Base class for immutable function vars."""
  26. @overload
  27. def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
  28. @overload
  29. def partial(
  30. self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
  31. arg1: V1 | Var[V1],
  32. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  33. @overload
  34. def partial(
  35. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
  36. arg1: V1 | Var[V1],
  37. arg2: V2 | Var[V2],
  38. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  39. @overload
  40. def partial(
  41. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
  42. arg1: V1 | Var[V1],
  43. arg2: V2 | Var[V2],
  44. arg3: V3 | Var[V3],
  45. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  46. @overload
  47. def partial(
  48. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
  49. arg1: V1 | Var[V1],
  50. arg2: V2 | Var[V2],
  51. arg3: V3 | Var[V3],
  52. arg4: V4 | Var[V4],
  53. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  54. @overload
  55. def partial(
  56. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
  57. arg1: V1 | Var[V1],
  58. arg2: V2 | Var[V2],
  59. arg3: V3 | Var[V3],
  60. arg4: V4 | Var[V4],
  61. arg5: V5 | Var[V5],
  62. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  63. @overload
  64. def partial(
  65. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
  66. arg1: V1 | Var[V1],
  67. arg2: V2 | Var[V2],
  68. arg3: V3 | Var[V3],
  69. arg4: V4 | Var[V4],
  70. arg5: V5 | Var[V5],
  71. arg6: V6 | Var[V6],
  72. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  73. @overload
  74. def partial(
  75. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  76. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  77. @overload
  78. def partial(self, *args: Var | Any) -> FunctionVar: ...
  79. def partial(self, *args: Var | Any) -> FunctionVar: # pyright: ignore [reportInconsistentOverload]
  80. """Partially apply the function with the given arguments.
  81. Args:
  82. *args: The arguments to partially apply the function with.
  83. Returns:
  84. The partially applied function.
  85. """
  86. if not args:
  87. return ArgsFunctionOperation.create((), self)
  88. return ArgsFunctionOperation.create(
  89. ("...args",),
  90. VarOperationCall.create(self, *args, Var(_js_expr="...args")),
  91. )
  92. @overload
  93. def call(
  94. self: FunctionVar[ReflexCallable[[V1], R]], arg1: V1 | Var[V1]
  95. ) -> VarOperationCall[[V1], R]: ...
  96. @overload
  97. def call(
  98. self: FunctionVar[ReflexCallable[[V1, V2], R]],
  99. arg1: V1 | Var[V1],
  100. arg2: V2 | Var[V2],
  101. ) -> VarOperationCall[[V1, V2], R]: ...
  102. @overload
  103. def call(
  104. self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
  105. arg1: V1 | Var[V1],
  106. arg2: V2 | Var[V2],
  107. arg3: V3 | Var[V3],
  108. ) -> VarOperationCall[[V1, V2, V3], R]: ...
  109. @overload
  110. def call(
  111. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
  112. arg1: V1 | Var[V1],
  113. arg2: V2 | Var[V2],
  114. arg3: V3 | Var[V3],
  115. arg4: V4 | Var[V4],
  116. ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
  117. @overload
  118. def call(
  119. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
  120. arg1: V1 | Var[V1],
  121. arg2: V2 | Var[V2],
  122. arg3: V3 | Var[V3],
  123. arg4: V4 | Var[V4],
  124. arg5: V5 | Var[V5],
  125. ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
  126. @overload
  127. def call(
  128. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
  129. arg1: V1 | Var[V1],
  130. arg2: V2 | Var[V2],
  131. arg3: V3 | Var[V3],
  132. arg4: V4 | Var[V4],
  133. arg5: V5 | Var[V5],
  134. arg6: V6 | Var[V6],
  135. ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
  136. @overload
  137. def call(
  138. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  139. ) -> VarOperationCall[P, R]: ...
  140. @overload
  141. def call(self, *args: Var | Any) -> Var: ...
  142. def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload]
  143. """Call the function with the given arguments.
  144. Args:
  145. *args: The arguments to call the function with.
  146. Returns:
  147. The function call operation.
  148. """
  149. return VarOperationCall.create(self, *args).guess_type()
  150. __call__ = call
  151. class BuilderFunctionVar(
  152. FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
  153. ):
  154. """Base class for immutable function vars with the builder pattern."""
  155. __call__ = FunctionVar.partial
  156. class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
  157. """Base class for immutable function vars from a string."""
  158. @classmethod
  159. def create(
  160. cls,
  161. func: str,
  162. _var_type: type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
  163. _var_data: VarData | None = None,
  164. ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
  165. """Create a new function var from a string.
  166. Args:
  167. func: The function to call.
  168. _var_type: The type of the Var.
  169. _var_data: Additional hooks and imports associated with the Var.
  170. Returns:
  171. The function var.
  172. """
  173. return FunctionStringVar(
  174. _js_expr=func,
  175. _var_type=_var_type,
  176. _var_data=_var_data,
  177. )
  178. @dataclasses.dataclass(
  179. eq=False,
  180. frozen=True,
  181. slots=True,
  182. )
  183. class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
  184. """Base class for immutable vars that are the result of a function call."""
  185. _func: FunctionVar[ReflexCallable[P, R]] | None = dataclasses.field(default=None)
  186. _args: tuple[Var | Any, ...] = dataclasses.field(default_factory=tuple)
  187. @cached_property_no_lock
  188. def _cached_var_name(self) -> str:
  189. """The name of the var.
  190. Returns:
  191. The name of the var.
  192. """
  193. return f"({self._func!s}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
  194. @cached_property_no_lock
  195. def _cached_get_all_var_data(self) -> VarData | None:
  196. """Get all the var data associated with the var.
  197. Returns:
  198. All the var data associated with the var.
  199. """
  200. return VarData.merge(
  201. self._func._get_all_var_data() if self._func is not None else None,
  202. *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
  203. self._var_data,
  204. )
  205. @classmethod
  206. def create(
  207. cls,
  208. func: FunctionVar[ReflexCallable[P, R]],
  209. *args: Var | Any,
  210. _var_type: GenericType = Any,
  211. _var_data: VarData | None = None,
  212. ) -> VarOperationCall:
  213. """Create a new function call var.
  214. Args:
  215. func: The function to call.
  216. *args: The arguments to call the function with.
  217. _var_type: The type of the Var.
  218. _var_data: Additional hooks and imports associated with the Var.
  219. Returns:
  220. The function call var.
  221. """
  222. function_return_type = (
  223. func._var_type.__args__[1]
  224. if getattr(func._var_type, "__args__", None)
  225. else Any
  226. )
  227. var_type = _var_type if _var_type is not Any else function_return_type
  228. return cls(
  229. _js_expr="",
  230. _var_type=var_type,
  231. _var_data=_var_data,
  232. _func=func,
  233. _args=args,
  234. )
  235. @dataclasses.dataclass(frozen=True)
  236. class DestructuredArg:
  237. """Class for destructured arguments."""
  238. fields: tuple[str, ...] = ()
  239. rest: str | None = None
  240. def to_javascript(self) -> str:
  241. """Convert the destructured argument to JavaScript.
  242. Returns:
  243. The destructured argument in JavaScript.
  244. """
  245. return format.wrap(
  246. ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
  247. "{",
  248. "}",
  249. )
  250. @dataclasses.dataclass(
  251. frozen=True,
  252. )
  253. class FunctionArgs:
  254. """Class for function arguments."""
  255. args: tuple[str | DestructuredArg, ...] = ()
  256. rest: str | None = None
  257. def format_args_function_operation(
  258. args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
  259. ) -> str:
  260. """Format an args function operation.
  261. Args:
  262. args: The function arguments.
  263. return_expr: The return expression.
  264. explicit_return: Whether to use explicit return syntax.
  265. Returns:
  266. The formatted args function operation.
  267. """
  268. arg_names_str = ", ".join(
  269. [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
  270. ) + (f", ...{args.rest}" if args.rest else "")
  271. return_expr_str = str(LiteralVar.create(return_expr))
  272. # Wrap return expression in curly braces if explicit return syntax is used.
  273. return_expr_str_wrapped = (
  274. format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
  275. )
  276. return f"(({arg_names_str}) => {return_expr_str_wrapped})"
  277. @dataclasses.dataclass(
  278. eq=False,
  279. frozen=True,
  280. slots=True,
  281. )
  282. class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
  283. """Base class for immutable function defined via arguments and return expression."""
  284. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  285. _return_expr: Var | Any = dataclasses.field(default=None)
  286. _explicit_return: bool = dataclasses.field(default=False)
  287. @cached_property_no_lock
  288. def _cached_var_name(self) -> str:
  289. """The name of the var.
  290. Returns:
  291. The name of the var.
  292. """
  293. return format_args_function_operation(
  294. self._args, self._return_expr, self._explicit_return
  295. )
  296. @classmethod
  297. def create(
  298. cls,
  299. args_names: Sequence[str | DestructuredArg],
  300. return_expr: Var | Any,
  301. rest: str | None = None,
  302. explicit_return: bool = False,
  303. _var_type: GenericType = Callable,
  304. _var_data: VarData | None = None,
  305. ):
  306. """Create a new function var.
  307. Args:
  308. args_names: The names of the arguments.
  309. return_expr: The return expression of the function.
  310. rest: The name of the rest argument.
  311. explicit_return: Whether to use explicit return syntax.
  312. _var_type: The type of the Var.
  313. _var_data: Additional hooks and imports associated with the Var.
  314. Returns:
  315. The function var.
  316. """
  317. return_expr = Var.create(return_expr)
  318. return cls(
  319. _js_expr="",
  320. _var_type=_var_type,
  321. _var_data=_var_data,
  322. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  323. _return_expr=return_expr,
  324. _explicit_return=explicit_return,
  325. )
  326. @dataclasses.dataclass(
  327. eq=False,
  328. frozen=True,
  329. slots=True,
  330. )
  331. class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
  332. """Base class for immutable function defined via arguments and return expression with the builder pattern."""
  333. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  334. _return_expr: Var | Any = dataclasses.field(default=None)
  335. _explicit_return: bool = dataclasses.field(default=False)
  336. @cached_property_no_lock
  337. def _cached_var_name(self) -> str:
  338. """The name of the var.
  339. Returns:
  340. The name of the var.
  341. """
  342. return format_args_function_operation(
  343. self._args, self._return_expr, self._explicit_return
  344. )
  345. @classmethod
  346. def create(
  347. cls,
  348. args_names: Sequence[str | DestructuredArg],
  349. return_expr: Var | Any,
  350. rest: str | None = None,
  351. explicit_return: bool = False,
  352. _var_type: GenericType = Callable,
  353. _var_data: VarData | None = None,
  354. ):
  355. """Create a new function var.
  356. Args:
  357. args_names: The names of the arguments.
  358. return_expr: The return expression of the function.
  359. rest: The name of the rest argument.
  360. explicit_return: Whether to use explicit return syntax.
  361. _var_type: The type of the Var.
  362. _var_data: Additional hooks and imports associated with the Var.
  363. Returns:
  364. The function var.
  365. """
  366. return_expr = Var.create(return_expr)
  367. return cls(
  368. _js_expr="",
  369. _var_type=_var_type,
  370. _var_data=_var_data,
  371. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  372. _return_expr=return_expr,
  373. _explicit_return=explicit_return,
  374. )
  375. JSON_STRINGIFY = FunctionStringVar.create(
  376. "JSON.stringify", _var_type=ReflexCallable[[Any], str]
  377. )
  378. ARRAY_ISARRAY = FunctionStringVar.create(
  379. "Array.isArray", _var_type=ReflexCallable[[Any], bool]
  380. )
  381. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  382. "((__to_string) => __to_string.toString())",
  383. _var_type=ReflexCallable[[Any], str],
  384. )