function.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. """Immutable function vars."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import sys
  5. from typing import (
  6. Any,
  7. Callable,
  8. Concatenate,
  9. Generic,
  10. ParamSpec,
  11. Protocol,
  12. Sequence,
  13. Type,
  14. TypeVar,
  15. overload,
  16. )
  17. from reflex.utils import format
  18. from reflex.utils.types import GenericType
  19. from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
# Parameter specification of a function var's signature.
P = ParamSpec("P")
# Types of individual leading positional arguments; used by the fixed-arity
# ``partial``/``call`` overloads of FunctionVar (one to six arguments).
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
V6 = TypeVar("V6")
# Return type of a function var's signature.
R = TypeVar("R")
class ReflexCallable(Protocol[P, R]):
    """Protocol for a callable.

    Generic over the parameter specification ``P`` and the return type ``R``.
    Subscripted forms such as ``ReflexCallable[[Any], str]`` are used as the
    ``_var_type`` of function vars to describe their signature.
    """

    # The call signature, expressed as an attribute annotation of type Callable[P, R].
    __call__: Callable[P, R]
# Type parameter of FunctionVar: the ReflexCallable signature the var wraps.
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, covariant=True)
# A second, independent callable type variable, used where a method returns a
# var whose signature is unrelated to the class's own CALLABLE_TYPE
# (e.g. FunctionStringVar.create).
OTHER_CALLABLE_TYPE = TypeVar(
    "OTHER_CALLABLE_TYPE", bound=ReflexCallable, covariant=True
)
  35. class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
  36. """Base class for immutable function vars."""
  37. @overload
  38. def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
  39. @overload
  40. def partial(
  41. self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
  42. arg1: V1 | Var[V1],
  43. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  44. @overload
  45. def partial(
  46. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
  47. arg1: V1 | Var[V1],
  48. arg2: V2 | Var[V2],
  49. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  50. @overload
  51. def partial(
  52. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
  53. arg1: V1 | Var[V1],
  54. arg2: V2 | Var[V2],
  55. arg3: V3 | Var[V3],
  56. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  57. @overload
  58. def partial(
  59. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
  60. arg1: V1 | Var[V1],
  61. arg2: V2 | Var[V2],
  62. arg3: V3 | Var[V3],
  63. arg4: V4 | Var[V4],
  64. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  65. @overload
  66. def partial(
  67. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
  68. arg1: V1 | Var[V1],
  69. arg2: V2 | Var[V2],
  70. arg3: V3 | Var[V3],
  71. arg4: V4 | Var[V4],
  72. arg5: V5 | Var[V5],
  73. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  74. @overload
  75. def partial(
  76. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
  77. arg1: V1 | Var[V1],
  78. arg2: V2 | Var[V2],
  79. arg3: V3 | Var[V3],
  80. arg4: V4 | Var[V4],
  81. arg5: V5 | Var[V5],
  82. arg6: V6 | Var[V6],
  83. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  84. @overload
  85. def partial(
  86. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  87. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  88. @overload
  89. def partial(self, *args: Var | Any) -> FunctionVar: ...
  90. def partial(self, *args: Var | Any) -> FunctionVar: # pyright: ignore [reportInconsistentOverload]
  91. """Partially apply the function with the given arguments.
  92. Args:
  93. *args: The arguments to partially apply the function with.
  94. Returns:
  95. The partially applied function.
  96. """
  97. if not args:
  98. return ArgsFunctionOperation.create((), self)
  99. return ArgsFunctionOperation.create(
  100. ("...args",),
  101. VarOperationCall.create(self, *args, Var(_js_expr="...args")),
  102. )
  103. @overload
  104. def call(
  105. self: FunctionVar[ReflexCallable[[V1], R]], arg1: V1 | Var[V1]
  106. ) -> VarOperationCall[[V1], R]: ...
  107. @overload
  108. def call(
  109. self: FunctionVar[ReflexCallable[[V1, V2], R]],
  110. arg1: V1 | Var[V1],
  111. arg2: V2 | Var[V2],
  112. ) -> VarOperationCall[[V1, V2], R]: ...
  113. @overload
  114. def call(
  115. self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
  116. arg1: V1 | Var[V1],
  117. arg2: V2 | Var[V2],
  118. arg3: V3 | Var[V3],
  119. ) -> VarOperationCall[[V1, V2, V3], R]: ...
  120. @overload
  121. def call(
  122. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
  123. arg1: V1 | Var[V1],
  124. arg2: V2 | Var[V2],
  125. arg3: V3 | Var[V3],
  126. arg4: V4 | Var[V4],
  127. ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
  128. @overload
  129. def call(
  130. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
  131. arg1: V1 | Var[V1],
  132. arg2: V2 | Var[V2],
  133. arg3: V3 | Var[V3],
  134. arg4: V4 | Var[V4],
  135. arg5: V5 | Var[V5],
  136. ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
  137. @overload
  138. def call(
  139. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
  140. arg1: V1 | Var[V1],
  141. arg2: V2 | Var[V2],
  142. arg3: V3 | Var[V3],
  143. arg4: V4 | Var[V4],
  144. arg5: V5 | Var[V5],
  145. arg6: V6 | Var[V6],
  146. ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
  147. @overload
  148. def call(
  149. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  150. ) -> VarOperationCall[P, R]: ...
  151. @overload
  152. def call(self, *args: Var | Any) -> Var: ...
  153. def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload]
  154. """Call the function with the given arguments.
  155. Args:
  156. *args: The arguments to call the function with.
  157. Returns:
  158. The function call operation.
  159. """
  160. return VarOperationCall.create(self, *args).guess_type()
  161. __call__ = call
class BuilderFunctionVar(
    FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
):
    """Base class for immutable function vars with the builder pattern.

    Calling an instance partially applies it (``FunctionVar.partial``) instead
    of producing a call expression (``FunctionVar.call``).
    """

    # Bound to the very same function object as FunctionVar.partial, so
    # ``builder(a, b)`` partially applies ``a`` and ``b``.
    __call__ = FunctionVar.partial
  167. class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
  168. """Base class for immutable function vars from a string."""
  169. @classmethod
  170. def create(
  171. cls,
  172. func: str,
  173. _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
  174. _var_data: VarData | None = None,
  175. ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
  176. """Create a new function var from a string.
  177. Args:
  178. func: The function to call.
  179. _var_type: The type of the Var.
  180. _var_data: Additional hooks and imports associated with the Var.
  181. Returns:
  182. The function var.
  183. """
  184. return FunctionStringVar(
  185. _js_expr=func,
  186. _var_type=_var_type,
  187. _var_data=_var_data,
  188. )
  189. @dataclasses.dataclass(
  190. eq=False,
  191. frozen=True,
  192. slots=True,
  193. )
  194. class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
  195. """Base class for immutable vars that are the result of a function call."""
  196. _func: FunctionVar[ReflexCallable[P, R]] | None = dataclasses.field(default=None)
  197. _args: tuple[Var | Any, ...] = dataclasses.field(default_factory=tuple)
  198. @cached_property_no_lock
  199. def _cached_var_name(self) -> str:
  200. """The name of the var.
  201. Returns:
  202. The name of the var.
  203. """
  204. return f"({self._func!s}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
  205. @cached_property_no_lock
  206. def _cached_get_all_var_data(self) -> VarData | None:
  207. """Get all the var data associated with the var.
  208. Returns:
  209. All the var data associated with the var.
  210. """
  211. return VarData.merge(
  212. self._func._get_all_var_data() if self._func is not None else None,
  213. *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
  214. self._var_data,
  215. )
  216. @classmethod
  217. def create(
  218. cls,
  219. func: FunctionVar[ReflexCallable[P, R]],
  220. *args: Var | Any,
  221. _var_type: GenericType = Any,
  222. _var_data: VarData | None = None,
  223. ) -> VarOperationCall:
  224. """Create a new function call var.
  225. Args:
  226. func: The function to call.
  227. *args: The arguments to call the function with.
  228. _var_type: The type of the Var.
  229. _var_data: Additional hooks and imports associated with the Var.
  230. Returns:
  231. The function call var.
  232. """
  233. function_return_type = (
  234. func._var_type.__args__[1]
  235. if getattr(func._var_type, "__args__", None)
  236. else Any
  237. )
  238. var_type = _var_type if _var_type is not Any else function_return_type
  239. return cls(
  240. _js_expr="",
  241. _var_type=var_type,
  242. _var_data=_var_data,
  243. _func=func,
  244. _args=args,
  245. )
  246. @dataclasses.dataclass(frozen=True)
  247. class DestructuredArg:
  248. """Class for destructured arguments."""
  249. fields: tuple[str, ...] = ()
  250. rest: str | None = None
  251. def to_javascript(self) -> str:
  252. """Convert the destructured argument to JavaScript.
  253. Returns:
  254. The destructured argument in JavaScript.
  255. """
  256. return format.wrap(
  257. ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
  258. "{",
  259. "}",
  260. )
  261. @dataclasses.dataclass(
  262. frozen=True,
  263. )
  264. class FunctionArgs:
  265. """Class for function arguments."""
  266. args: tuple[str | DestructuredArg, ...] = ()
  267. rest: str | None = None
  268. def format_args_function_operation(
  269. args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
  270. ) -> str:
  271. """Format an args function operation.
  272. Args:
  273. args: The function arguments.
  274. return_expr: The return expression.
  275. explicit_return: Whether to use explicit return syntax.
  276. Returns:
  277. The formatted args function operation.
  278. """
  279. arg_names_str = ", ".join(
  280. [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
  281. ) + (f", ...{args.rest}" if args.rest else "")
  282. return_expr_str = str(LiteralVar.create(return_expr))
  283. # Wrap return expression in curly braces if explicit return syntax is used.
  284. return_expr_str_wrapped = (
  285. format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
  286. )
  287. return f"(({arg_names_str}) => {return_expr_str_wrapped})"
  288. @dataclasses.dataclass(
  289. eq=False,
  290. frozen=True,
  291. slots=True,
  292. )
  293. class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
  294. """Base class for immutable function defined via arguments and return expression."""
  295. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  296. _return_expr: Var | Any = dataclasses.field(default=None)
  297. _explicit_return: bool = dataclasses.field(default=False)
  298. @cached_property_no_lock
  299. def _cached_var_name(self) -> str:
  300. """The name of the var.
  301. Returns:
  302. The name of the var.
  303. """
  304. return format_args_function_operation(
  305. self._args, self._return_expr, self._explicit_return
  306. )
  307. @classmethod
  308. def create(
  309. cls,
  310. args_names: Sequence[str | DestructuredArg],
  311. return_expr: Var | Any,
  312. rest: str | None = None,
  313. explicit_return: bool = False,
  314. _var_type: GenericType = Callable,
  315. _var_data: VarData | None = None,
  316. ):
  317. """Create a new function var.
  318. Args:
  319. args_names: The names of the arguments.
  320. return_expr: The return expression of the function.
  321. rest: The name of the rest argument.
  322. explicit_return: Whether to use explicit return syntax.
  323. _var_type: The type of the Var.
  324. _var_data: Additional hooks and imports associated with the Var.
  325. Returns:
  326. The function var.
  327. """
  328. return_expr = Var.create(return_expr)
  329. return cls(
  330. _js_expr="",
  331. _var_type=_var_type,
  332. _var_data=_var_data,
  333. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  334. _return_expr=return_expr,
  335. _explicit_return=explicit_return,
  336. )
  337. @dataclasses.dataclass(
  338. eq=False,
  339. frozen=True,
  340. slots=True,
  341. )
  342. class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
  343. """Base class for immutable function defined via arguments and return expression with the builder pattern."""
  344. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  345. _return_expr: Var | Any = dataclasses.field(default=None)
  346. _explicit_return: bool = dataclasses.field(default=False)
  347. @cached_property_no_lock
  348. def _cached_var_name(self) -> str:
  349. """The name of the var.
  350. Returns:
  351. The name of the var.
  352. """
  353. return format_args_function_operation(
  354. self._args, self._return_expr, self._explicit_return
  355. )
  356. @classmethod
  357. def create(
  358. cls,
  359. args_names: Sequence[str | DestructuredArg],
  360. return_expr: Var | Any,
  361. rest: str | None = None,
  362. explicit_return: bool = False,
  363. _var_type: GenericType = Callable,
  364. _var_data: VarData | None = None,
  365. ):
  366. """Create a new function var.
  367. Args:
  368. args_names: The names of the arguments.
  369. return_expr: The return expression of the function.
  370. rest: The name of the rest argument.
  371. explicit_return: Whether to use explicit return syntax.
  372. _var_type: The type of the Var.
  373. _var_data: Additional hooks and imports associated with the Var.
  374. Returns:
  375. The function var.
  376. """
  377. return_expr = Var.create(return_expr)
  378. return cls(
  379. _js_expr="",
  380. _var_type=_var_type,
  381. _var_data=_var_data,
  382. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  383. _return_expr=return_expr,
  384. _explicit_return=explicit_return,
  385. )
  386. if python_version := sys.version_info[:2] >= (3, 10):
  387. JSON_STRINGIFY = FunctionStringVar.create(
  388. "JSON.stringify", _var_type=ReflexCallable[[Any], str]
  389. )
  390. ARRAY_ISARRAY = FunctionStringVar.create(
  391. "Array.isArray", _var_type=ReflexCallable[[Any], bool]
  392. )
  393. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  394. "((__to_string) => __to_string.toString())",
  395. _var_type=ReflexCallable[[Any], str],
  396. )
  397. else:
  398. JSON_STRINGIFY = FunctionStringVar.create(
  399. "JSON.stringify", _var_type=ReflexCallable[Any, str]
  400. )
  401. ARRAY_ISARRAY = FunctionStringVar.create(
  402. "Array.isArray", _var_type=ReflexCallable[Any, bool]
  403. )
  404. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  405. "((__to_string) => __to_string.toString())",
  406. _var_type=ReflexCallable[Any, str],
  407. )