function.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. """Immutable function vars."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import sys
  5. from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
  6. from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
  7. from reflex.utils import format
  8. from reflex.utils.exceptions import VarTypeError
  9. from reflex.utils.types import GenericType
  10. from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
  11. P = ParamSpec("P")
  12. V1 = TypeVar("V1")
  13. V2 = TypeVar("V2")
  14. V3 = TypeVar("V3")
  15. V4 = TypeVar("V4")
  16. V5 = TypeVar("V5")
  17. V6 = TypeVar("V6")
  18. R = TypeVar("R")
  19. class ReflexCallable(Protocol[P, R]):
  20. """Protocol for a callable."""
  21. __call__: Callable[P, R]
  22. CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
  23. OTHER_CALLABLE_TYPE = TypeVar(
  24. "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
  25. )
  26. class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
  27. """Base class for immutable function vars."""
  28. @overload
  29. def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
  30. @overload
  31. def partial(
  32. self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
  33. arg1: Union[V1, Var[V1]],
  34. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  35. @overload
  36. def partial(
  37. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
  38. arg1: Union[V1, Var[V1]],
  39. arg2: Union[V2, Var[V2]],
  40. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  41. @overload
  42. def partial(
  43. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
  44. arg1: Union[V1, Var[V1]],
  45. arg2: Union[V2, Var[V2]],
  46. arg3: Union[V3, Var[V3]],
  47. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  48. @overload
  49. def partial(
  50. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
  51. arg1: Union[V1, Var[V1]],
  52. arg2: Union[V2, Var[V2]],
  53. arg3: Union[V3, Var[V3]],
  54. arg4: Union[V4, Var[V4]],
  55. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  56. @overload
  57. def partial(
  58. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
  59. arg1: Union[V1, Var[V1]],
  60. arg2: Union[V2, Var[V2]],
  61. arg3: Union[V3, Var[V3]],
  62. arg4: Union[V4, Var[V4]],
  63. arg5: Union[V5, Var[V5]],
  64. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  65. @overload
  66. def partial(
  67. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
  68. arg1: Union[V1, Var[V1]],
  69. arg2: Union[V2, Var[V2]],
  70. arg3: Union[V3, Var[V3]],
  71. arg4: Union[V4, Var[V4]],
  72. arg5: Union[V5, Var[V5]],
  73. arg6: Union[V6, Var[V6]],
  74. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  75. @overload
  76. def partial(
  77. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  78. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  79. @overload
  80. def partial(self, *args: Var | Any) -> FunctionVar: ...
  81. def partial(self, *args: Var | Any) -> FunctionVar: # type: ignore
  82. """Partially apply the function with the given arguments.
  83. Args:
  84. *args: The arguments to partially apply the function with.
  85. Returns:
  86. The partially applied function.
  87. """
  88. if not args:
  89. return self
  90. remaining_validators = self._pre_check(*args)
  91. if self.__call__ is self.partial:
  92. # if the default behavior is partial, we should return a new partial function
  93. return ArgsFunctionOperationBuilder.create(
  94. (),
  95. VarOperationCall.create(self, *args, Var(_js_expr="...args")),
  96. rest="args",
  97. validators=remaining_validators,
  98. )
  99. return ArgsFunctionOperation.create(
  100. (),
  101. VarOperationCall.create(self, *args, Var(_js_expr="...args")),
  102. rest="args",
  103. validators=remaining_validators,
  104. )
  105. @overload
  106. def call(
  107. self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
  108. ) -> VarOperationCall[[V1], R]: ...
  109. @overload
  110. def call(
  111. self: FunctionVar[ReflexCallable[[V1, V2], R]],
  112. arg1: Union[V1, Var[V1]],
  113. arg2: Union[V2, Var[V2]],
  114. ) -> VarOperationCall[[V1, V2], R]: ...
  115. @overload
  116. def call(
  117. self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
  118. arg1: Union[V1, Var[V1]],
  119. arg2: Union[V2, Var[V2]],
  120. arg3: Union[V3, Var[V3]],
  121. ) -> VarOperationCall[[V1, V2, V3], R]: ...
  122. @overload
  123. def call(
  124. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
  125. arg1: Union[V1, Var[V1]],
  126. arg2: Union[V2, Var[V2]],
  127. arg3: Union[V3, Var[V3]],
  128. arg4: Union[V4, Var[V4]],
  129. ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
  130. @overload
  131. def call(
  132. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
  133. arg1: Union[V1, Var[V1]],
  134. arg2: Union[V2, Var[V2]],
  135. arg3: Union[V3, Var[V3]],
  136. arg4: Union[V4, Var[V4]],
  137. arg5: Union[V5, Var[V5]],
  138. ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
  139. @overload
  140. def call(
  141. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
  142. arg1: Union[V1, Var[V1]],
  143. arg2: Union[V2, Var[V2]],
  144. arg3: Union[V3, Var[V3]],
  145. arg4: Union[V4, Var[V4]],
  146. arg5: Union[V5, Var[V5]],
  147. arg6: Union[V6, Var[V6]],
  148. ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
  149. @overload
  150. def call(
  151. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  152. ) -> VarOperationCall[P, R]: ...
  153. @overload
  154. def call(self, *args: Var | Any) -> Var: ...
  155. def call(self, *args: Var | Any) -> Var: # type: ignore
  156. """Call the function with the given arguments.
  157. Args:
  158. *args: The arguments to call the function with.
  159. Returns:
  160. The function call operation.
  161. """
  162. self._pre_check(*args)
  163. return VarOperationCall.create(self, *args).guess_type()
  164. def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
  165. """Check if the function can be called with the given arguments.
  166. Args:
  167. *args: The arguments to call the function with.
  168. Returns:
  169. True if the function can be called with the given arguments.
  170. """
  171. return tuple()
  172. __call__ = call
  173. class BuilderFunctionVar(
  174. FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
  175. ):
  176. """Base class for immutable function vars with the builder pattern."""
  177. __call__ = FunctionVar.partial
  178. class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
  179. """Base class for immutable function vars from a string."""
  180. @classmethod
  181. def create(
  182. cls,
  183. func: str,
  184. _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
  185. _var_data: VarData | None = None,
  186. ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
  187. """Create a new function var from a string.
  188. Args:
  189. func: The function to call.
  190. _var_data: Additional hooks and imports associated with the Var.
  191. Returns:
  192. The function var.
  193. """
  194. return FunctionStringVar(
  195. _js_expr=func,
  196. _var_type=_var_type,
  197. _var_data=_var_data,
  198. )
  199. @dataclasses.dataclass(
  200. eq=False,
  201. frozen=True,
  202. **{"slots": True} if sys.version_info >= (3, 10) else {},
  203. )
  204. class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
  205. """Base class for immutable vars that are the result of a function call."""
  206. _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
  207. _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
  208. @cached_property_no_lock
  209. def _cached_var_name(self) -> str:
  210. """The name of the var.
  211. Returns:
  212. The name of the var.
  213. """
  214. return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
  215. @cached_property_no_lock
  216. def _cached_get_all_var_data(self) -> VarData | None:
  217. """Get all the var data associated with the var.
  218. Returns:
  219. All the var data associated with the var.
  220. """
  221. return VarData.merge(
  222. self._func._get_all_var_data() if self._func is not None else None,
  223. *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
  224. self._var_data,
  225. )
  226. @classmethod
  227. def create(
  228. cls,
  229. func: FunctionVar[ReflexCallable[P, R]],
  230. *args: Var | Any,
  231. _var_type: GenericType = Any,
  232. _var_data: VarData | None = None,
  233. ) -> VarOperationCall:
  234. """Create a new function call var.
  235. Args:
  236. func: The function to call.
  237. *args: The arguments to call the function with.
  238. _var_data: Additional hooks and imports associated with the Var.
  239. Returns:
  240. The function call var.
  241. """
  242. function_return_type = (
  243. func._var_type.__args__[1]
  244. if getattr(func._var_type, "__args__", None)
  245. else Any
  246. )
  247. var_type = _var_type if _var_type is not Any else function_return_type
  248. return cls(
  249. _js_expr="",
  250. _var_type=var_type,
  251. _var_data=_var_data,
  252. _func=func,
  253. _args=args,
  254. )
  255. @dataclasses.dataclass(frozen=True)
  256. class DestructuredArg:
  257. """Class for destructured arguments."""
  258. fields: Tuple[str, ...] = tuple()
  259. rest: Optional[str] = None
  260. def to_javascript(self) -> str:
  261. """Convert the destructured argument to JavaScript.
  262. Returns:
  263. The destructured argument in JavaScript.
  264. """
  265. return format.wrap(
  266. ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
  267. "{",
  268. "}",
  269. )
  270. @dataclasses.dataclass(
  271. frozen=True,
  272. )
  273. class FunctionArgs:
  274. """Class for function arguments."""
  275. args: Tuple[Union[str, DestructuredArg], ...] = tuple()
  276. rest: Optional[str] = None
  277. def format_args_function_operation(
  278. args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
  279. ) -> str:
  280. """Format an args function operation.
  281. Args:
  282. args: The function arguments.
  283. return_expr: The return expression.
  284. explicit_return: Whether to use explicit return syntax.
  285. Returns:
  286. The formatted args function operation.
  287. """
  288. arg_names_str = ", ".join(
  289. [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
  290. + ([f"...{args.rest}"] if args.rest else [])
  291. )
  292. return_expr_str = str(LiteralVar.create(return_expr))
  293. # Wrap return expression in curly braces if explicit return syntax is used.
  294. return_expr_str_wrapped = (
  295. format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
  296. )
  297. return f"(({arg_names_str}) => {return_expr_str_wrapped})"
  298. @dataclasses.dataclass(
  299. eq=False,
  300. frozen=True,
  301. **{"slots": True} if sys.version_info >= (3, 10) else {},
  302. )
  303. class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
  304. """Base class for immutable function defined via arguments and return expression."""
  305. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  306. _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
  307. default_factory=tuple
  308. )
  309. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  310. _function_name: str = dataclasses.field(default="")
  311. _explicit_return: bool = dataclasses.field(default=False)
  312. @cached_property_no_lock
  313. def _cached_var_name(self) -> str:
  314. """The name of the var.
  315. Returns:
  316. The name of the var.
  317. """
  318. return format_args_function_operation(
  319. self._args, self._return_expr, self._explicit_return
  320. )
  321. def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
  322. """Check if the function can be called with the given arguments.
  323. Args:
  324. *args: The arguments to call the function with.
  325. Returns:
  326. True if the function can be called with the given arguments.
  327. """
  328. for i, (validator, arg) in enumerate(zip(self._validators, args)):
  329. if not validator(arg):
  330. arg_name = self._args.args[i] if i < len(self._args.args) else None
  331. if arg_name is not None:
  332. raise VarTypeError(
  333. f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
  334. )
  335. raise VarTypeError(
  336. f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
  337. )
  338. return self._validators[len(args) :]
  339. @classmethod
  340. def create(
  341. cls,
  342. args_names: Sequence[Union[str, DestructuredArg]],
  343. return_expr: Var | Any,
  344. rest: str | None = None,
  345. validators: Sequence[Callable[[Any], bool]] = (),
  346. function_name: str = "",
  347. explicit_return: bool = False,
  348. _var_type: GenericType = Callable,
  349. _var_data: VarData | None = None,
  350. ):
  351. """Create a new function var.
  352. Args:
  353. args_names: The names of the arguments.
  354. return_expr: The return expression of the function.
  355. rest: The name of the rest argument.
  356. validators: The validators for the arguments.
  357. function_name: The name of the function.
  358. explicit_return: Whether to use explicit return syntax.
  359. _var_data: Additional hooks and imports associated with the Var.
  360. Returns:
  361. The function var.
  362. """
  363. return cls(
  364. _js_expr="",
  365. _var_type=_var_type,
  366. _var_data=_var_data,
  367. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  368. _function_name=function_name,
  369. _validators=tuple(validators),
  370. _return_expr=return_expr,
  371. _explicit_return=explicit_return,
  372. )
  373. @dataclasses.dataclass(
  374. eq=False,
  375. frozen=True,
  376. **{"slots": True} if sys.version_info >= (3, 10) else {},
  377. )
  378. class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
  379. """Base class for immutable function defined via arguments and return expression with the builder pattern."""
  380. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  381. _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
  382. default_factory=tuple
  383. )
  384. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  385. _function_name: str = dataclasses.field(default="")
  386. _explicit_return: bool = dataclasses.field(default=False)
  387. @cached_property_no_lock
  388. def _cached_var_name(self) -> str:
  389. """The name of the var.
  390. Returns:
  391. The name of the var.
  392. """
  393. return format_args_function_operation(
  394. self._args, self._return_expr, self._explicit_return
  395. )
  396. def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
  397. """Check if the function can be called with the given arguments.
  398. Args:
  399. *args: The arguments to call the function with.
  400. Returns:
  401. True if the function can be called with the given arguments.
  402. """
  403. for i, (validator, arg) in enumerate(zip(self._validators, args)):
  404. if not validator(arg):
  405. arg_name = self._args.args[i] if i < len(self._args.args) else None
  406. if arg_name is not None:
  407. raise VarTypeError(
  408. f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
  409. )
  410. raise VarTypeError(
  411. f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
  412. )
  413. return self._validators[len(args) :]
  414. @classmethod
  415. def create(
  416. cls,
  417. args_names: Sequence[Union[str, DestructuredArg]],
  418. return_expr: Var | Any,
  419. rest: str | None = None,
  420. validators: Sequence[Callable[[Any], bool]] = (),
  421. function_name: str = "",
  422. explicit_return: bool = False,
  423. _var_type: GenericType = Callable,
  424. _var_data: VarData | None = None,
  425. ):
  426. """Create a new function var.
  427. Args:
  428. args_names: The names of the arguments.
  429. return_expr: The return expression of the function.
  430. rest: The name of the rest argument.
  431. validators: The validators for the arguments.
  432. function_name: The name of the function.
  433. explicit_return: Whether to use explicit return syntax.
  434. _var_data: Additional hooks and imports associated with the Var.
  435. Returns:
  436. The function var.
  437. """
  438. return cls(
  439. _js_expr="",
  440. _var_type=_var_type,
  441. _var_data=_var_data,
  442. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  443. _function_name=function_name,
  444. _validators=tuple(validators),
  445. _return_expr=return_expr,
  446. _explicit_return=explicit_return,
  447. )
  448. if python_version := sys.version_info[:2] >= (3, 10):
  449. JSON_STRINGIFY = FunctionStringVar.create(
  450. "JSON.stringify", _var_type=ReflexCallable[[Any], str]
  451. )
  452. ARRAY_ISARRAY = FunctionStringVar.create(
  453. "Array.isArray", _var_type=ReflexCallable[[Any], bool]
  454. )
  455. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  456. "((__to_string) => __to_string.toString())",
  457. _var_type=ReflexCallable[[Any], str],
  458. )
  459. else:
  460. JSON_STRINGIFY = FunctionStringVar.create(
  461. "JSON.stringify", _var_type=ReflexCallable[Any, str]
  462. )
  463. ARRAY_ISARRAY = FunctionStringVar.create(
  464. "Array.isArray", _var_type=ReflexCallable[Any, bool]
  465. )
  466. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  467. "((__to_string) => __to_string.toString())",
  468. _var_type=ReflexCallable[Any, str],
  469. )