function.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. """Immutable function vars."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import sys
  5. from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
  6. from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
  7. from reflex.utils import format
  8. from reflex.utils.types import GenericType
  9. from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
  10. P = ParamSpec("P")
  11. V1 = TypeVar("V1")
  12. V2 = TypeVar("V2")
  13. V3 = TypeVar("V3")
  14. V4 = TypeVar("V4")
  15. V5 = TypeVar("V5")
  16. V6 = TypeVar("V6")
  17. R = TypeVar("R")
  18. class ReflexCallable(Protocol[P, R]):
  19. """Protocol for a callable."""
  20. __call__: Callable[P, R]
  21. CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
  22. OTHER_CALLABLE_TYPE = TypeVar(
  23. "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
  24. )
  25. class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
  26. """Base class for immutable function vars."""
  27. @overload
  28. def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
  29. @overload
  30. def partial(
  31. self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
  32. arg1: Union[V1, Var[V1]],
  33. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  34. @overload
  35. def partial(
  36. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
  37. arg1: Union[V1, Var[V1]],
  38. arg2: Union[V2, Var[V2]],
  39. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  40. @overload
  41. def partial(
  42. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
  43. arg1: Union[V1, Var[V1]],
  44. arg2: Union[V2, Var[V2]],
  45. arg3: Union[V3, Var[V3]],
  46. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  47. @overload
  48. def partial(
  49. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
  50. arg1: Union[V1, Var[V1]],
  51. arg2: Union[V2, Var[V2]],
  52. arg3: Union[V3, Var[V3]],
  53. arg4: Union[V4, Var[V4]],
  54. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  55. @overload
  56. def partial(
  57. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
  58. arg1: Union[V1, Var[V1]],
  59. arg2: Union[V2, Var[V2]],
  60. arg3: Union[V3, Var[V3]],
  61. arg4: Union[V4, Var[V4]],
  62. arg5: Union[V5, Var[V5]],
  63. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  64. @overload
  65. def partial(
  66. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
  67. arg1: Union[V1, Var[V1]],
  68. arg2: Union[V2, Var[V2]],
  69. arg3: Union[V3, Var[V3]],
  70. arg4: Union[V4, Var[V4]],
  71. arg5: Union[V5, Var[V5]],
  72. arg6: Union[V6, Var[V6]],
  73. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  74. @overload
  75. def partial(
  76. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  77. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  78. @overload
  79. def partial(self, *args: Var | Any) -> FunctionVar: ...
  80. def partial(self, *args: Var | Any) -> FunctionVar: # type: ignore
  81. """Partially apply the function with the given arguments.
  82. Args:
  83. *args: The arguments to partially apply the function with.
  84. Returns:
  85. The partially applied function.
  86. """
  87. self._pre_check(*args)
  88. if not args:
  89. return ArgsFunctionOperation.create((), self)
  90. return ArgsFunctionOperation.create(
  91. ("...args",),
  92. VarOperationCall.create(self, *args, Var(_js_expr="...args")),
  93. )
  94. @overload
  95. def call(
  96. self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
  97. ) -> VarOperationCall[[V1], R]: ...
  98. @overload
  99. def call(
  100. self: FunctionVar[ReflexCallable[[V1, V2], R]],
  101. arg1: Union[V1, Var[V1]],
  102. arg2: Union[V2, Var[V2]],
  103. ) -> VarOperationCall[[V1, V2], R]: ...
  104. @overload
  105. def call(
  106. self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
  107. arg1: Union[V1, Var[V1]],
  108. arg2: Union[V2, Var[V2]],
  109. arg3: Union[V3, Var[V3]],
  110. ) -> VarOperationCall[[V1, V2, V3], R]: ...
  111. @overload
  112. def call(
  113. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
  114. arg1: Union[V1, Var[V1]],
  115. arg2: Union[V2, Var[V2]],
  116. arg3: Union[V3, Var[V3]],
  117. arg4: Union[V4, Var[V4]],
  118. ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
  119. @overload
  120. def call(
  121. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
  122. arg1: Union[V1, Var[V1]],
  123. arg2: Union[V2, Var[V2]],
  124. arg3: Union[V3, Var[V3]],
  125. arg4: Union[V4, Var[V4]],
  126. arg5: Union[V5, Var[V5]],
  127. ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
  128. @overload
  129. def call(
  130. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
  131. arg1: Union[V1, Var[V1]],
  132. arg2: Union[V2, Var[V2]],
  133. arg3: Union[V3, Var[V3]],
  134. arg4: Union[V4, Var[V4]],
  135. arg5: Union[V5, Var[V5]],
  136. arg6: Union[V6, Var[V6]],
  137. ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
  138. @overload
  139. def call(
  140. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  141. ) -> VarOperationCall[P, R]: ...
  142. @overload
  143. def call(self, *args: Var | Any) -> Var: ...
  144. def call(self, *args: Var | Any) -> Var: # type: ignore
  145. """Call the function with the given arguments.
  146. Args:
  147. *args: The arguments to call the function with.
  148. Returns:
  149. The function call operation.
  150. """
  151. self._pre_check(*args)
  152. return VarOperationCall.create(self, *args).guess_type()
  153. def _pre_check(self, *args: Var | Any) -> bool:
  154. """Check if the function can be called with the given arguments.
  155. Args:
  156. *args: The arguments to call the function with.
  157. Returns:
  158. True if the function can be called with the given arguments.
  159. """
  160. return True
  161. __call__ = call
  162. class BuilderFunctionVar(
  163. FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
  164. ):
  165. """Base class for immutable function vars with the builder pattern."""
  166. __call__ = FunctionVar.partial
  167. class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
  168. """Base class for immutable function vars from a string."""
  169. @classmethod
  170. def create(
  171. cls,
  172. func: str,
  173. _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
  174. _var_data: VarData | None = None,
  175. ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
  176. """Create a new function var from a string.
  177. Args:
  178. func: The function to call.
  179. _var_data: Additional hooks and imports associated with the Var.
  180. Returns:
  181. The function var.
  182. """
  183. return FunctionStringVar(
  184. _js_expr=func,
  185. _var_type=_var_type,
  186. _var_data=_var_data,
  187. )
  188. @dataclasses.dataclass(
  189. eq=False,
  190. frozen=True,
  191. **{"slots": True} if sys.version_info >= (3, 10) else {},
  192. )
  193. class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
  194. """Base class for immutable vars that are the result of a function call."""
  195. _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
  196. _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
  197. @cached_property_no_lock
  198. def _cached_var_name(self) -> str:
  199. """The name of the var.
  200. Returns:
  201. The name of the var.
  202. """
  203. return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
  204. @cached_property_no_lock
  205. def _cached_get_all_var_data(self) -> VarData | None:
  206. """Get all the var data associated with the var.
  207. Returns:
  208. All the var data associated with the var.
  209. """
  210. return VarData.merge(
  211. self._func._get_all_var_data() if self._func is not None else None,
  212. *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
  213. self._var_data,
  214. )
  215. @classmethod
  216. def create(
  217. cls,
  218. func: FunctionVar[ReflexCallable[P, R]],
  219. *args: Var | Any,
  220. _var_type: GenericType = Any,
  221. _var_data: VarData | None = None,
  222. ) -> VarOperationCall:
  223. """Create a new function call var.
  224. Args:
  225. func: The function to call.
  226. *args: The arguments to call the function with.
  227. _var_data: Additional hooks and imports associated with the Var.
  228. Returns:
  229. The function call var.
  230. """
  231. function_return_type = (
  232. func._var_type.__args__[1]
  233. if getattr(func._var_type, "__args__", None)
  234. else Any
  235. )
  236. var_type = _var_type if _var_type is not Any else function_return_type
  237. return cls(
  238. _js_expr="",
  239. _var_type=var_type,
  240. _var_data=_var_data,
  241. _func=func,
  242. _args=args,
  243. )
  244. @dataclasses.dataclass(frozen=True)
  245. class DestructuredArg:
  246. """Class for destructured arguments."""
  247. fields: Tuple[str, ...] = tuple()
  248. rest: Optional[str] = None
  249. def to_javascript(self) -> str:
  250. """Convert the destructured argument to JavaScript.
  251. Returns:
  252. The destructured argument in JavaScript.
  253. """
  254. return format.wrap(
  255. ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
  256. "{",
  257. "}",
  258. )
  259. @dataclasses.dataclass(
  260. frozen=True,
  261. )
  262. class FunctionArgs:
  263. """Class for function arguments."""
  264. args: Tuple[Union[str, DestructuredArg], ...] = tuple()
  265. rest: Optional[str] = None
  266. def format_args_function_operation(
  267. args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
  268. ) -> str:
  269. """Format an args function operation.
  270. Args:
  271. args: The function arguments.
  272. return_expr: The return expression.
  273. explicit_return: Whether to use explicit return syntax.
  274. Returns:
  275. The formatted args function operation.
  276. """
  277. arg_names_str = ", ".join(
  278. [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
  279. ) + (f", ...{args.rest}" if args.rest else "")
  280. return_expr_str = str(LiteralVar.create(return_expr))
  281. # Wrap return expression in curly braces if explicit return syntax is used.
  282. return_expr_str_wrapped = (
  283. format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
  284. )
  285. return f"(({arg_names_str}) => {return_expr_str_wrapped})"
  286. @dataclasses.dataclass(
  287. eq=False,
  288. frozen=True,
  289. **{"slots": True} if sys.version_info >= (3, 10) else {},
  290. )
  291. class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
  292. """Base class for immutable function defined via arguments and return expression."""
  293. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  294. _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
  295. default_factory=tuple
  296. )
  297. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  298. _explicit_return: bool = dataclasses.field(default=False)
  299. @cached_property_no_lock
  300. def _cached_var_name(self) -> str:
  301. """The name of the var.
  302. Returns:
  303. The name of the var.
  304. """
  305. return format_args_function_operation(
  306. self._args, self._return_expr, self._explicit_return
  307. )
  308. def _pre_check(self, *args: Var | Any) -> bool:
  309. """Check if the function can be called with the given arguments.
  310. Args:
  311. *args: The arguments to call the function with.
  312. Returns:
  313. True if the function can be called with the given arguments.
  314. """
  315. return all(
  316. validator(arg)
  317. for validator, arg in zip(self._validators, args, strict=False)
  318. )
  319. @classmethod
  320. def create(
  321. cls,
  322. args_names: Sequence[Union[str, DestructuredArg]],
  323. return_expr: Var | Any,
  324. rest: str | None = None,
  325. validators: Sequence[Callable[[Any], bool]] = (),
  326. explicit_return: bool = False,
  327. _var_type: GenericType = Callable,
  328. _var_data: VarData | None = None,
  329. ):
  330. """Create a new function var.
  331. Args:
  332. args_names: The names of the arguments.
  333. return_expr: The return expression of the function.
  334. rest: The name of the rest argument.
  335. explicit_return: Whether to use explicit return syntax.
  336. _var_data: Additional hooks and imports associated with the Var.
  337. Returns:
  338. The function var.
  339. """
  340. return cls(
  341. _js_expr="",
  342. _var_type=_var_type,
  343. _var_data=_var_data,
  344. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  345. _validators=tuple(validators),
  346. _return_expr=return_expr,
  347. _explicit_return=explicit_return,
  348. )
  349. @dataclasses.dataclass(
  350. eq=False,
  351. frozen=True,
  352. **{"slots": True} if sys.version_info >= (3, 10) else {},
  353. )
  354. class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
  355. """Base class for immutable function defined via arguments and return expression with the builder pattern."""
  356. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  357. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  358. _explicit_return: bool = dataclasses.field(default=False)
  359. @cached_property_no_lock
  360. def _cached_var_name(self) -> str:
  361. """The name of the var.
  362. Returns:
  363. The name of the var.
  364. """
  365. return format_args_function_operation(
  366. self._args, self._return_expr, self._explicit_return
  367. )
  368. @classmethod
  369. def create(
  370. cls,
  371. args_names: Sequence[Union[str, DestructuredArg]],
  372. return_expr: Var | Any,
  373. rest: str | None = None,
  374. explicit_return: bool = False,
  375. _var_type: GenericType = Callable,
  376. _var_data: VarData | None = None,
  377. ):
  378. """Create a new function var.
  379. Args:
  380. args_names: The names of the arguments.
  381. return_expr: The return expression of the function.
  382. rest: The name of the rest argument.
  383. explicit_return: Whether to use explicit return syntax.
  384. _var_data: Additional hooks and imports associated with the Var.
  385. Returns:
  386. The function var.
  387. """
  388. return cls(
  389. _js_expr="",
  390. _var_type=_var_type,
  391. _var_data=_var_data,
  392. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  393. _return_expr=return_expr,
  394. _explicit_return=explicit_return,
  395. )
  396. if python_version := sys.version_info[:2] >= (3, 10):
  397. JSON_STRINGIFY = FunctionStringVar.create(
  398. "JSON.stringify", _var_type=ReflexCallable[[Any], str]
  399. )
  400. ARRAY_ISARRAY = FunctionStringVar.create(
  401. "Array.isArray", _var_type=ReflexCallable[[Any], bool]
  402. )
  403. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  404. "((__to_string) => __to_string.toString())",
  405. _var_type=ReflexCallable[[Any], str],
  406. )
  407. else:
  408. JSON_STRINGIFY = FunctionStringVar.create(
  409. "JSON.stringify", _var_type=ReflexCallable[Any, str]
  410. )
  411. ARRAY_ISARRAY = FunctionStringVar.create(
  412. "Array.isArray", _var_type=ReflexCallable[Any, bool]
  413. )
  414. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  415. "((__to_string) => __to_string.toString())",
  416. _var_type=ReflexCallable[Any, str],
  417. )