# function.py (14 KB)
  1. """Immutable function vars."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import sys
  5. from typing import (
  6. Any,
  7. Callable,
  8. Concatenate,
  9. Generic,
  10. Optional,
  11. ParamSpec,
  12. Protocol,
  13. Sequence,
  14. Tuple,
  15. Type,
  16. TypeVar,
  17. Union,
  18. overload,
  19. )
  20. from reflex.utils import format
  21. from reflex.utils.types import GenericType
  22. from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
  23. P = ParamSpec("P")
  24. V1 = TypeVar("V1")
  25. V2 = TypeVar("V2")
  26. V3 = TypeVar("V3")
  27. V4 = TypeVar("V4")
  28. V5 = TypeVar("V5")
  29. V6 = TypeVar("V6")
  30. R = TypeVar("R")
  31. class ReflexCallable(Protocol[P, R]):
  32. """Protocol for a callable."""
  33. __call__: Callable[P, R]
  34. CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, covariant=True)
  35. OTHER_CALLABLE_TYPE = TypeVar(
  36. "OTHER_CALLABLE_TYPE", bound=ReflexCallable, covariant=True
  37. )
  38. class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
  39. """Base class for immutable function vars."""
  40. @overload
  41. def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
  42. @overload
  43. def partial(
  44. self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
  45. arg1: Union[V1, Var[V1]],
  46. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  47. @overload
  48. def partial(
  49. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
  50. arg1: Union[V1, Var[V1]],
  51. arg2: Union[V2, Var[V2]],
  52. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  53. @overload
  54. def partial(
  55. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
  56. arg1: Union[V1, Var[V1]],
  57. arg2: Union[V2, Var[V2]],
  58. arg3: Union[V3, Var[V3]],
  59. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  60. @overload
  61. def partial(
  62. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
  63. arg1: Union[V1, Var[V1]],
  64. arg2: Union[V2, Var[V2]],
  65. arg3: Union[V3, Var[V3]],
  66. arg4: Union[V4, Var[V4]],
  67. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  68. @overload
  69. def partial(
  70. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
  71. arg1: Union[V1, Var[V1]],
  72. arg2: Union[V2, Var[V2]],
  73. arg3: Union[V3, Var[V3]],
  74. arg4: Union[V4, Var[V4]],
  75. arg5: Union[V5, Var[V5]],
  76. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  77. @overload
  78. def partial(
  79. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
  80. arg1: Union[V1, Var[V1]],
  81. arg2: Union[V2, Var[V2]],
  82. arg3: Union[V3, Var[V3]],
  83. arg4: Union[V4, Var[V4]],
  84. arg5: Union[V5, Var[V5]],
  85. arg6: Union[V6, Var[V6]],
  86. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  87. @overload
  88. def partial(
  89. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  90. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  91. @overload
  92. def partial(self, *args: Var | Any) -> FunctionVar: ...
  93. def partial(self, *args: Var | Any) -> FunctionVar: # pyright: ignore [reportInconsistentOverload]
  94. """Partially apply the function with the given arguments.
  95. Args:
  96. *args: The arguments to partially apply the function with.
  97. Returns:
  98. The partially applied function.
  99. """
  100. if not args:
  101. return ArgsFunctionOperation.create((), self)
  102. return ArgsFunctionOperation.create(
  103. ("...args",),
  104. VarOperationCall.create(self, *args, Var(_js_expr="...args")),
  105. )
  106. @overload
  107. def call(
  108. self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
  109. ) -> VarOperationCall[[V1], R]: ...
  110. @overload
  111. def call(
  112. self: FunctionVar[ReflexCallable[[V1, V2], R]],
  113. arg1: Union[V1, Var[V1]],
  114. arg2: Union[V2, Var[V2]],
  115. ) -> VarOperationCall[[V1, V2], R]: ...
  116. @overload
  117. def call(
  118. self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
  119. arg1: Union[V1, Var[V1]],
  120. arg2: Union[V2, Var[V2]],
  121. arg3: Union[V3, Var[V3]],
  122. ) -> VarOperationCall[[V1, V2, V3], R]: ...
  123. @overload
  124. def call(
  125. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
  126. arg1: Union[V1, Var[V1]],
  127. arg2: Union[V2, Var[V2]],
  128. arg3: Union[V3, Var[V3]],
  129. arg4: Union[V4, Var[V4]],
  130. ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
  131. @overload
  132. def call(
  133. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
  134. arg1: Union[V1, Var[V1]],
  135. arg2: Union[V2, Var[V2]],
  136. arg3: Union[V3, Var[V3]],
  137. arg4: Union[V4, Var[V4]],
  138. arg5: Union[V5, Var[V5]],
  139. ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
  140. @overload
  141. def call(
  142. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
  143. arg1: Union[V1, Var[V1]],
  144. arg2: Union[V2, Var[V2]],
  145. arg3: Union[V3, Var[V3]],
  146. arg4: Union[V4, Var[V4]],
  147. arg5: Union[V5, Var[V5]],
  148. arg6: Union[V6, Var[V6]],
  149. ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
  150. @overload
  151. def call(
  152. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  153. ) -> VarOperationCall[P, R]: ...
  154. @overload
  155. def call(self, *args: Var | Any) -> Var: ...
  156. def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload]
  157. """Call the function with the given arguments.
  158. Args:
  159. *args: The arguments to call the function with.
  160. Returns:
  161. The function call operation.
  162. """
  163. return VarOperationCall.create(self, *args).guess_type()
  164. __call__ = call
  165. class BuilderFunctionVar(
  166. FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
  167. ):
  168. """Base class for immutable function vars with the builder pattern."""
  169. __call__ = FunctionVar.partial
  170. class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
  171. """Base class for immutable function vars from a string."""
  172. @classmethod
  173. def create(
  174. cls,
  175. func: str,
  176. _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
  177. _var_data: VarData | None = None,
  178. ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
  179. """Create a new function var from a string.
  180. Args:
  181. func: The function to call.
  182. _var_type: The type of the Var.
  183. _var_data: Additional hooks and imports associated with the Var.
  184. Returns:
  185. The function var.
  186. """
  187. return FunctionStringVar(
  188. _js_expr=func,
  189. _var_type=_var_type,
  190. _var_data=_var_data,
  191. )
  192. @dataclasses.dataclass(
  193. eq=False,
  194. frozen=True,
  195. slots=True,
  196. )
  197. class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
  198. """Base class for immutable vars that are the result of a function call."""
  199. _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
  200. _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
  201. @cached_property_no_lock
  202. def _cached_var_name(self) -> str:
  203. """The name of the var.
  204. Returns:
  205. The name of the var.
  206. """
  207. return f"({self._func!s}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
  208. @cached_property_no_lock
  209. def _cached_get_all_var_data(self) -> VarData | None:
  210. """Get all the var data associated with the var.
  211. Returns:
  212. All the var data associated with the var.
  213. """
  214. return VarData.merge(
  215. self._func._get_all_var_data() if self._func is not None else None,
  216. *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
  217. self._var_data,
  218. )
  219. @classmethod
  220. def create(
  221. cls,
  222. func: FunctionVar[ReflexCallable[P, R]],
  223. *args: Var | Any,
  224. _var_type: GenericType = Any,
  225. _var_data: VarData | None = None,
  226. ) -> VarOperationCall:
  227. """Create a new function call var.
  228. Args:
  229. func: The function to call.
  230. *args: The arguments to call the function with.
  231. _var_type: The type of the Var.
  232. _var_data: Additional hooks and imports associated with the Var.
  233. Returns:
  234. The function call var.
  235. """
  236. function_return_type = (
  237. func._var_type.__args__[1]
  238. if getattr(func._var_type, "__args__", None)
  239. else Any
  240. )
  241. var_type = _var_type if _var_type is not Any else function_return_type
  242. return cls(
  243. _js_expr="",
  244. _var_type=var_type,
  245. _var_data=_var_data,
  246. _func=func,
  247. _args=args,
  248. )
  249. @dataclasses.dataclass(frozen=True)
  250. class DestructuredArg:
  251. """Class for destructured arguments."""
  252. fields: Tuple[str, ...] = ()
  253. rest: Optional[str] = None
  254. def to_javascript(self) -> str:
  255. """Convert the destructured argument to JavaScript.
  256. Returns:
  257. The destructured argument in JavaScript.
  258. """
  259. return format.wrap(
  260. ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
  261. "{",
  262. "}",
  263. )
  264. @dataclasses.dataclass(
  265. frozen=True,
  266. )
  267. class FunctionArgs:
  268. """Class for function arguments."""
  269. args: Tuple[Union[str, DestructuredArg], ...] = ()
  270. rest: Optional[str] = None
  271. def format_args_function_operation(
  272. args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
  273. ) -> str:
  274. """Format an args function operation.
  275. Args:
  276. args: The function arguments.
  277. return_expr: The return expression.
  278. explicit_return: Whether to use explicit return syntax.
  279. Returns:
  280. The formatted args function operation.
  281. """
  282. arg_names_str = ", ".join(
  283. [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
  284. ) + (f", ...{args.rest}" if args.rest else "")
  285. return_expr_str = str(LiteralVar.create(return_expr))
  286. # Wrap return expression in curly braces if explicit return syntax is used.
  287. return_expr_str_wrapped = (
  288. format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
  289. )
  290. return f"(({arg_names_str}) => {return_expr_str_wrapped})"
  291. @dataclasses.dataclass(
  292. eq=False,
  293. frozen=True,
  294. slots=True,
  295. )
  296. class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
  297. """Base class for immutable function defined via arguments and return expression."""
  298. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  299. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  300. _explicit_return: bool = dataclasses.field(default=False)
  301. @cached_property_no_lock
  302. def _cached_var_name(self) -> str:
  303. """The name of the var.
  304. Returns:
  305. The name of the var.
  306. """
  307. return format_args_function_operation(
  308. self._args, self._return_expr, self._explicit_return
  309. )
  310. @classmethod
  311. def create(
  312. cls,
  313. args_names: Sequence[Union[str, DestructuredArg]],
  314. return_expr: Var | Any,
  315. rest: str | None = None,
  316. explicit_return: bool = False,
  317. _var_type: GenericType = Callable,
  318. _var_data: VarData | None = None,
  319. ):
  320. """Create a new function var.
  321. Args:
  322. args_names: The names of the arguments.
  323. return_expr: The return expression of the function.
  324. rest: The name of the rest argument.
  325. explicit_return: Whether to use explicit return syntax.
  326. _var_type: The type of the Var.
  327. _var_data: Additional hooks and imports associated with the Var.
  328. Returns:
  329. The function var.
  330. """
  331. return_expr = Var.create(return_expr)
  332. return cls(
  333. _js_expr="",
  334. _var_type=_var_type,
  335. _var_data=_var_data,
  336. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  337. _return_expr=return_expr,
  338. _explicit_return=explicit_return,
  339. )
  340. @dataclasses.dataclass(
  341. eq=False,
  342. frozen=True,
  343. slots=True,
  344. )
  345. class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
  346. """Base class for immutable function defined via arguments and return expression with the builder pattern."""
  347. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  348. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  349. _explicit_return: bool = dataclasses.field(default=False)
  350. @cached_property_no_lock
  351. def _cached_var_name(self) -> str:
  352. """The name of the var.
  353. Returns:
  354. The name of the var.
  355. """
  356. return format_args_function_operation(
  357. self._args, self._return_expr, self._explicit_return
  358. )
  359. @classmethod
  360. def create(
  361. cls,
  362. args_names: Sequence[Union[str, DestructuredArg]],
  363. return_expr: Var | Any,
  364. rest: str | None = None,
  365. explicit_return: bool = False,
  366. _var_type: GenericType = Callable,
  367. _var_data: VarData | None = None,
  368. ):
  369. """Create a new function var.
  370. Args:
  371. args_names: The names of the arguments.
  372. return_expr: The return expression of the function.
  373. rest: The name of the rest argument.
  374. explicit_return: Whether to use explicit return syntax.
  375. _var_type: The type of the Var.
  376. _var_data: Additional hooks and imports associated with the Var.
  377. Returns:
  378. The function var.
  379. """
  380. return_expr = Var.create(return_expr)
  381. return cls(
  382. _js_expr="",
  383. _var_type=_var_type,
  384. _var_data=_var_data,
  385. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  386. _return_expr=return_expr,
  387. _explicit_return=explicit_return,
  388. )
  389. if python_version := sys.version_info[:2] >= (3, 10):
  390. JSON_STRINGIFY = FunctionStringVar.create(
  391. "JSON.stringify", _var_type=ReflexCallable[[Any], str]
  392. )
  393. ARRAY_ISARRAY = FunctionStringVar.create(
  394. "Array.isArray", _var_type=ReflexCallable[[Any], bool]
  395. )
  396. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  397. "((__to_string) => __to_string.toString())",
  398. _var_type=ReflexCallable[[Any], str],
  399. )
  400. else:
  401. JSON_STRINGIFY = FunctionStringVar.create(
  402. "JSON.stringify", _var_type=ReflexCallable[Any, str]
  403. )
  404. ARRAY_ISARRAY = FunctionStringVar.create(
  405. "Array.isArray", _var_type=ReflexCallable[Any, bool]
  406. )
  407. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  408. "((__to_string) => __to_string.toString())",
  409. _var_type=ReflexCallable[Any, str],
  410. )