function.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. """Immutable function vars."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import sys
  5. from typing_extensions import (
  6. Any,
  7. Callable,
  8. Concatenate,
  9. Generic,
  10. Optional,
  11. ParamSpec,
  12. Protocol,
  13. Sequence,
  14. Tuple,
  15. Type,
  16. TypeVar,
  17. Union,
  18. overload,
  19. )
  20. from reflex.utils import format
  21. from reflex.utils.types import GenericType
  22. from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
# Type variables shared by the typing overloads in this module.
# P captures a callable's parameter list (a ParamSpec).
P = ParamSpec("P")
# V1..V6 stand for the types of individual positional arguments; they are used
# by the fixed-arity overloads of FunctionVar.partial and FunctionVar.call.
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
V6 = TypeVar("V6")
# R is a callable's return type.
R = TypeVar("R")
class ReflexCallable(Protocol[P, R]):
    """Protocol for a callable.

    A structural type parameterized by the parameter specification ``P`` and
    the return type ``R``. It is used as the type argument of ``FunctionVar``
    to record the signature of the function a var refers to, e.g.
    ``ReflexCallable[[Any], str]``.
    """

    # The call signature: accepts P's parameters and returns R.
    __call__: Callable[P, R]
# Type of the callable wrapped by a FunctionVar, bounded by ReflexCallable.
# infer_variance=True asks type checkers to infer the variance from usage
# instead of it being declared explicitly.
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
# A second, independent callable TypeVar for signatures whose result type must
# not be tied to the class's own CALLABLE_TYPE (see FunctionStringVar.create).
OTHER_CALLABLE_TYPE = TypeVar(
    "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
)
class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
    """Base class for immutable function vars.

    A FunctionVar stands for a JavaScript function expression.

    - ``partial`` builds an arrow function that pre-binds leading arguments.
    - ``call`` (also exposed as ``__call__``) builds a call expression.

    The ``@overload`` stubs exist only for static type checking. They match up
    to six positional arguments against the parameter types of the wrapped
    ``ReflexCallable`` so the result type can be narrowed. Larger calls fall
    back to the variadic overloads.
    """

    @overload
    def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
        arg1: Union[V1, Var[V1]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
        arg5: Union[V5, Var[V5]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
        arg5: Union[V5, Var[V5]],
        arg6: Union[V6, Var[V6]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(self, *args: Var | Any) -> FunctionVar: ...

    def partial(self, *args: Var | Any) -> FunctionVar:  # type: ignore
        """Partially apply the function with the given arguments.

        With arguments, the result is an arrow function with a single
        ``...args`` rest parameter. Its body calls this function with the
        bound arguments first, followed by ``...args``.

        Args:
            *args: The arguments to partially apply the function with.

        Returns:
            The partially applied function.
        """
        if not args:
            # NOTE(review): with no arguments this builds an arrow function
            # with no parameters whose body is this function expression, i.e.
            # it appears to produce ``(() => f)``. That returns f instead of
            # behaving like f, even though the first overload types the result
            # as FunctionVar[CALLABLE_TYPE]. Confirm this is intended.
            return ArgsFunctionOperation.create((), self)
        return ArgsFunctionOperation.create(
            ("...args",),
            # Bound arguments first, then forward whatever the result is
            # later called with via the rest parameter.
            VarOperationCall.create(self, *args, Var(_js_expr="...args")),
        )

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
    ) -> VarOperationCall[[V1], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
    ) -> VarOperationCall[[V1, V2], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
    ) -> VarOperationCall[[V1, V2, V3], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
    ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
        arg5: Union[V5, Var[V5]],
    ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
        arg5: Union[V5, Var[V5]],
        arg6: Union[V6, Var[V6]],
    ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
    ) -> VarOperationCall[P, R]: ...

    @overload
    def call(self, *args: Var | Any) -> Var: ...

    def call(self, *args: Var | Any) -> Var:  # type: ignore
        """Call the function with the given arguments.

        Args:
            *args: The arguments to call the function with.

        Returns:
            The function call operation.
        """
        # guess_type() is defined outside this file; presumably it converts
        # the call var into a Var subclass matching its inferred type.
        # TODO confirm.
        return VarOperationCall.create(self, *args).guess_type()

    # Calling a FunctionVar directly (``fn(a, b)``) is the same as ``fn.call(a, b)``.
    __call__ = call
class BuilderFunctionVar(
    FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
):
    """Base class for immutable function vars with the builder pattern.

    Unlike FunctionVar, calling an instance partially applies it rather than
    producing a call expression.
    """

    # ``var(a, b)`` runs FunctionVar.partial: it returns a new function var with
    # the arguments bound, instead of a call of the function.
    __call__ = FunctionVar.partial
  170. class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
  171. """Base class for immutable function vars from a string."""
  172. @classmethod
  173. def create(
  174. cls,
  175. func: str,
  176. _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
  177. _var_data: VarData | None = None,
  178. ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
  179. """Create a new function var from a string.
  180. Args:
  181. func: The function to call.
  182. _var_data: Additional hooks and imports associated with the Var.
  183. Returns:
  184. The function var.
  185. """
  186. return FunctionStringVar(
  187. _js_expr=func,
  188. _var_type=_var_type,
  189. _var_data=_var_data,
  190. )
  191. @dataclasses.dataclass(
  192. eq=False,
  193. frozen=True,
  194. **{"slots": True} if sys.version_info >= (3, 10) else {},
  195. )
  196. class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
  197. """Base class for immutable vars that are the result of a function call."""
  198. _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
  199. _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
  200. @cached_property_no_lock
  201. def _cached_var_name(self) -> str:
  202. """The name of the var.
  203. Returns:
  204. The name of the var.
  205. """
  206. return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
  207. @cached_property_no_lock
  208. def _cached_get_all_var_data(self) -> VarData | None:
  209. """Get all the var data associated with the var.
  210. Returns:
  211. All the var data associated with the var.
  212. """
  213. return VarData.merge(
  214. self._func._get_all_var_data() if self._func is not None else None,
  215. *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
  216. self._var_data,
  217. )
  218. @classmethod
  219. def create(
  220. cls,
  221. func: FunctionVar[ReflexCallable[P, R]],
  222. *args: Var | Any,
  223. _var_type: GenericType = Any,
  224. _var_data: VarData | None = None,
  225. ) -> VarOperationCall:
  226. """Create a new function call var.
  227. Args:
  228. func: The function to call.
  229. *args: The arguments to call the function with.
  230. _var_data: Additional hooks and imports associated with the Var.
  231. Returns:
  232. The function call var.
  233. """
  234. function_return_type = (
  235. func._var_type.__args__[1]
  236. if getattr(func._var_type, "__args__", None)
  237. else Any
  238. )
  239. var_type = _var_type if _var_type is not Any else function_return_type
  240. return cls(
  241. _js_expr="",
  242. _var_type=var_type,
  243. _var_data=_var_data,
  244. _func=func,
  245. _args=args,
  246. )
  247. @dataclasses.dataclass(frozen=True)
  248. class DestructuredArg:
  249. """Class for destructured arguments."""
  250. fields: Tuple[str, ...] = tuple()
  251. rest: Optional[str] = None
  252. def to_javascript(self) -> str:
  253. """Convert the destructured argument to JavaScript.
  254. Returns:
  255. The destructured argument in JavaScript.
  256. """
  257. return format.wrap(
  258. ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
  259. "{",
  260. "}",
  261. )
  262. @dataclasses.dataclass(
  263. frozen=True,
  264. )
  265. class FunctionArgs:
  266. """Class for function arguments."""
  267. args: Tuple[Union[str, DestructuredArg], ...] = tuple()
  268. rest: Optional[str] = None
  269. def format_args_function_operation(
  270. args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
  271. ) -> str:
  272. """Format an args function operation.
  273. Args:
  274. args: The function arguments.
  275. return_expr: The return expression.
  276. explicit_return: Whether to use explicit return syntax.
  277. Returns:
  278. The formatted args function operation.
  279. """
  280. arg_names_str = ", ".join(
  281. [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
  282. ) + (f", ...{args.rest}" if args.rest else "")
  283. return_expr_str = str(LiteralVar.create(return_expr))
  284. # Wrap return expression in curly braces if explicit return syntax is used.
  285. return_expr_str_wrapped = (
  286. format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
  287. )
  288. return f"(({arg_names_str}) => {return_expr_str_wrapped})"
  289. @dataclasses.dataclass(
  290. eq=False,
  291. frozen=True,
  292. **{"slots": True} if sys.version_info >= (3, 10) else {},
  293. )
  294. class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
  295. """Base class for immutable function defined via arguments and return expression."""
  296. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  297. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  298. _explicit_return: bool = dataclasses.field(default=False)
  299. @cached_property_no_lock
  300. def _cached_var_name(self) -> str:
  301. """The name of the var.
  302. Returns:
  303. The name of the var.
  304. """
  305. return format_args_function_operation(
  306. self._args, self._return_expr, self._explicit_return
  307. )
  308. @classmethod
  309. def create(
  310. cls,
  311. args_names: Sequence[Union[str, DestructuredArg]],
  312. return_expr: Var | Any,
  313. rest: str | None = None,
  314. explicit_return: bool = False,
  315. _var_type: GenericType = Callable,
  316. _var_data: VarData | None = None,
  317. ):
  318. """Create a new function var.
  319. Args:
  320. args_names: The names of the arguments.
  321. return_expr: The return expression of the function.
  322. rest: The name of the rest argument.
  323. explicit_return: Whether to use explicit return syntax.
  324. _var_data: Additional hooks and imports associated with the Var.
  325. Returns:
  326. The function var.
  327. """
  328. return cls(
  329. _js_expr="",
  330. _var_type=_var_type,
  331. _var_data=_var_data,
  332. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  333. _return_expr=return_expr,
  334. _explicit_return=explicit_return,
  335. )
  336. @dataclasses.dataclass(
  337. eq=False,
  338. frozen=True,
  339. **{"slots": True} if sys.version_info >= (3, 10) else {},
  340. )
  341. class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
  342. """Base class for immutable function defined via arguments and return expression with the builder pattern."""
  343. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  344. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  345. _explicit_return: bool = dataclasses.field(default=False)
  346. @cached_property_no_lock
  347. def _cached_var_name(self) -> str:
  348. """The name of the var.
  349. Returns:
  350. The name of the var.
  351. """
  352. return format_args_function_operation(
  353. self._args, self._return_expr, self._explicit_return
  354. )
  355. @classmethod
  356. def create(
  357. cls,
  358. args_names: Sequence[Union[str, DestructuredArg]],
  359. return_expr: Var | Any,
  360. rest: str | None = None,
  361. explicit_return: bool = False,
  362. _var_type: GenericType = Callable,
  363. _var_data: VarData | None = None,
  364. ):
  365. """Create a new function var.
  366. Args:
  367. args_names: The names of the arguments.
  368. return_expr: The return expression of the function.
  369. rest: The name of the rest argument.
  370. explicit_return: Whether to use explicit return syntax.
  371. _var_data: Additional hooks and imports associated with the Var.
  372. Returns:
  373. The function var.
  374. """
  375. return cls(
  376. _js_expr="",
  377. _var_type=_var_type,
  378. _var_data=_var_data,
  379. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  380. _return_expr=return_expr,
  381. _explicit_return=explicit_return,
  382. )
  383. if python_version := sys.version_info[:2] >= (3, 10):
  384. JSON_STRINGIFY = FunctionStringVar.create(
  385. "JSON.stringify", _var_type=ReflexCallable[[Any], str]
  386. )
  387. ARRAY_ISARRAY = FunctionStringVar.create(
  388. "Array.isArray", _var_type=ReflexCallable[[Any], bool]
  389. )
  390. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  391. "((__to_string) => __to_string.toString())",
  392. _var_type=ReflexCallable[[Any], str],
  393. )
  394. else:
  395. JSON_STRINGIFY = FunctionStringVar.create(
  396. "JSON.stringify", _var_type=ReflexCallable[Any, str]
  397. )
  398. ARRAY_ISARRAY = FunctionStringVar.create(
  399. "Array.isArray", _var_type=ReflexCallable[Any, bool]
  400. )
  401. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  402. "((__to_string) => __to_string.toString())",
  403. _var_type=ReflexCallable[Any, str],
  404. )