function.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. """Immutable function vars."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import sys
  5. from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
  6. from typing_extensions import Concatenate, Generic, ParamSpec, TypeVar
  7. from reflex.utils import format
  8. from reflex.utils.exceptions import VarTypeError
  9. from reflex.utils.types import GenericType
  10. from .base import (
  11. CachedVarOperation,
  12. LiteralVar,
  13. ReflexCallable,
  14. TypeComputer,
  15. Var,
  16. VarData,
  17. cached_property_no_lock,
  18. unwrap_reflex_callalbe,
  19. )
  20. P = ParamSpec("P")
  21. R = TypeVar("R")
  22. V1 = TypeVar("V1")
  23. V2 = TypeVar("V2")
  24. V3 = TypeVar("V3")
  25. V4 = TypeVar("V4")
  26. V5 = TypeVar("V5")
  27. V6 = TypeVar("V6")
  28. CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
  29. OTHER_CALLABLE_TYPE = TypeVar(
  30. "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
  31. )
  32. class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
  33. """Base class for immutable function vars."""
  34. @overload
  35. def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
  36. @overload
  37. def partial(
  38. self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
  39. arg1: Union[V1, Var[V1]],
  40. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  41. @overload
  42. def partial(
  43. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
  44. arg1: Union[V1, Var[V1]],
  45. arg2: Union[V2, Var[V2]],
  46. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  47. @overload
  48. def partial(
  49. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
  50. arg1: Union[V1, Var[V1]],
  51. arg2: Union[V2, Var[V2]],
  52. arg3: Union[V3, Var[V3]],
  53. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  54. @overload
  55. def partial(
  56. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
  57. arg1: Union[V1, Var[V1]],
  58. arg2: Union[V2, Var[V2]],
  59. arg3: Union[V3, Var[V3]],
  60. arg4: Union[V4, Var[V4]],
  61. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  62. @overload
  63. def partial(
  64. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
  65. arg1: Union[V1, Var[V1]],
  66. arg2: Union[V2, Var[V2]],
  67. arg3: Union[V3, Var[V3]],
  68. arg4: Union[V4, Var[V4]],
  69. arg5: Union[V5, Var[V5]],
  70. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  71. @overload
  72. def partial(
  73. self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
  74. arg1: Union[V1, Var[V1]],
  75. arg2: Union[V2, Var[V2]],
  76. arg3: Union[V3, Var[V3]],
  77. arg4: Union[V4, Var[V4]],
  78. arg5: Union[V5, Var[V5]],
  79. arg6: Union[V6, Var[V6]],
  80. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  81. @overload
  82. def partial(
  83. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  84. ) -> FunctionVar[ReflexCallable[P, R]]: ...
  85. @overload
  86. def partial(self, *args: Var | Any) -> FunctionVar: ...
  87. def partial(self, *args: Var | Any) -> FunctionVar: # type: ignore
  88. """Partially apply the function with the given arguments.
  89. Args:
  90. *args: The arguments to partially apply the function with.
  91. Returns:
  92. The partially applied function.
  93. """
  94. if not args:
  95. return self
  96. args = tuple(map(LiteralVar.create, args))
  97. remaining_validators = self._pre_check(*args)
  98. partial_types, type_computer = self._partial_type(*args)
  99. if self.__call__ is self.partial:
  100. # if the default behavior is partial, we should return a new partial function
  101. return ArgsFunctionOperationBuilder.create(
  102. (),
  103. VarOperationCall.create(
  104. self,
  105. *args,
  106. Var(_js_expr="...args"),
  107. _var_type=self._return_type(*args),
  108. ),
  109. rest="args",
  110. validators=remaining_validators,
  111. type_computer=type_computer,
  112. _var_type=partial_types,
  113. )
  114. return ArgsFunctionOperation.create(
  115. (),
  116. VarOperationCall.create(
  117. self, *args, Var(_js_expr="...args"), _var_type=self._return_type(*args)
  118. ),
  119. rest="args",
  120. validators=remaining_validators,
  121. type_computer=type_computer,
  122. _var_type=partial_types,
  123. )
  124. @overload
  125. def call(
  126. self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
  127. ) -> VarOperationCall[[V1], R]: ...
  128. @overload
  129. def call(
  130. self: FunctionVar[ReflexCallable[[V1, V2], R]],
  131. arg1: Union[V1, Var[V1]],
  132. arg2: Union[V2, Var[V2]],
  133. ) -> VarOperationCall[[V1, V2], R]: ...
  134. @overload
  135. def call(
  136. self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
  137. arg1: Union[V1, Var[V1]],
  138. arg2: Union[V2, Var[V2]],
  139. arg3: Union[V3, Var[V3]],
  140. ) -> VarOperationCall[[V1, V2, V3], R]: ...
  141. @overload
  142. def call(
  143. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
  144. arg1: Union[V1, Var[V1]],
  145. arg2: Union[V2, Var[V2]],
  146. arg3: Union[V3, Var[V3]],
  147. arg4: Union[V4, Var[V4]],
  148. ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
  149. @overload
  150. def call(
  151. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
  152. arg1: Union[V1, Var[V1]],
  153. arg2: Union[V2, Var[V2]],
  154. arg3: Union[V3, Var[V3]],
  155. arg4: Union[V4, Var[V4]],
  156. arg5: Union[V5, Var[V5]],
  157. ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
  158. @overload
  159. def call(
  160. self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
  161. arg1: Union[V1, Var[V1]],
  162. arg2: Union[V2, Var[V2]],
  163. arg3: Union[V3, Var[V3]],
  164. arg4: Union[V4, Var[V4]],
  165. arg5: Union[V5, Var[V5]],
  166. arg6: Union[V6, Var[V6]],
  167. ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
  168. @overload
  169. def call(
  170. self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
  171. ) -> VarOperationCall[P, R]: ...
  172. @overload
  173. def call(self, *args: Var | Any) -> Var: ...
  174. def call(self, *args: Var | Any) -> Var: # type: ignore
  175. """Call the function with the given arguments.
  176. Args:
  177. *args: The arguments to call the function with.
  178. Returns:
  179. The function call operation.
  180. Raises:
  181. VarTypeError: If the number of arguments is invalid
  182. """
  183. arg_len = self._arg_len()
  184. if arg_len is not None and len(args) != arg_len:
  185. raise VarTypeError(f"Invalid number of arguments provided to {str(self)}")
  186. args = tuple(map(LiteralVar.create, args))
  187. self._pre_check(*args)
  188. return_type = self._return_type(*args)
  189. return VarOperationCall.create(self, *args, _var_type=return_type).guess_type()
  190. def _partial_type(
  191. self, *args: Var | Any
  192. ) -> Tuple[GenericType, Optional[TypeComputer]]:
  193. """Override the type of the function call with the given arguments.
  194. Args:
  195. *args: The arguments to call the function with.
  196. Returns:
  197. The overridden type of the function call.
  198. """
  199. args_types, return_type = unwrap_reflex_callalbe(self._var_type)
  200. if isinstance(args_types, tuple):
  201. return ReflexCallable[[*args_types[len(args) :]], return_type], None # type: ignore
  202. return ReflexCallable[..., return_type], None
  203. def _arg_len(self) -> int | None:
  204. """Get the number of arguments the function takes.
  205. Returns:
  206. The number of arguments the function takes.
  207. """
  208. args_types, _ = unwrap_reflex_callalbe(self._var_type)
  209. if isinstance(args_types, tuple):
  210. return len(args_types)
  211. return None
  212. def _return_type(self, *args: Var | Any) -> GenericType:
  213. """Override the type of the function call with the given arguments.
  214. Args:
  215. *args: The arguments to call the function with.
  216. Returns:
  217. The overridden type of the function call.
  218. """
  219. partial_types, _ = self._partial_type(*args)
  220. return unwrap_reflex_callalbe(partial_types)[1]
  221. def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
  222. """Check if the function can be called with the given arguments.
  223. Args:
  224. *args: The arguments to call the function with.
  225. Returns:
  226. True if the function can be called with the given arguments.
  227. """
  228. return tuple()
  229. __call__ = call
  230. class BuilderFunctionVar(
  231. FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
  232. ):
  233. """Base class for immutable function vars with the builder pattern."""
  234. __call__ = FunctionVar.partial
  235. class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
  236. """Base class for immutable function vars from a string."""
  237. @classmethod
  238. def create(
  239. cls,
  240. func: str,
  241. _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
  242. _var_data: VarData | None = None,
  243. ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
  244. """Create a new function var from a string.
  245. Args:
  246. func: The function to call.
  247. _var_data: Additional hooks and imports associated with the Var.
  248. Returns:
  249. The function var.
  250. """
  251. return FunctionStringVar(
  252. _js_expr=func,
  253. _var_type=_var_type,
  254. _var_data=_var_data,
  255. )
  256. @dataclasses.dataclass(
  257. eq=False,
  258. frozen=True,
  259. **{"slots": True} if sys.version_info >= (3, 10) else {},
  260. )
  261. class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
  262. """Base class for immutable vars that are the result of a function call."""
  263. _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
  264. _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
  265. @cached_property_no_lock
  266. def _cached_var_name(self) -> str:
  267. """The name of the var.
  268. Returns:
  269. The name of the var.
  270. """
  271. return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
  272. @cached_property_no_lock
  273. def _cached_get_all_var_data(self) -> VarData | None:
  274. """Get all the var data associated with the var.
  275. Returns:
  276. All the var data associated with the var.
  277. """
  278. return VarData.merge(
  279. self._func._get_all_var_data() if self._func is not None else None,
  280. *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
  281. self._var_data,
  282. )
  283. @classmethod
  284. def create(
  285. cls,
  286. func: FunctionVar[ReflexCallable[P, R]],
  287. *args: Var | Any,
  288. _var_type: GenericType = Any,
  289. _var_data: VarData | None = None,
  290. ) -> VarOperationCall:
  291. """Create a new function call var.
  292. Args:
  293. func: The function to call.
  294. *args: The arguments to call the function with.
  295. _var_data: Additional hooks and imports associated with the Var.
  296. Returns:
  297. The function call var.
  298. """
  299. function_return_type = (
  300. func._var_type.__args__[1]
  301. if getattr(func._var_type, "__args__", None)
  302. else Any
  303. )
  304. var_type = _var_type if _var_type is not Any else function_return_type
  305. return cls(
  306. _js_expr="",
  307. _var_type=var_type,
  308. _var_data=_var_data,
  309. _func=func,
  310. _args=args,
  311. )
  312. @dataclasses.dataclass(frozen=True)
  313. class DestructuredArg:
  314. """Class for destructured arguments."""
  315. fields: Tuple[str, ...] = tuple()
  316. rest: Optional[str] = None
  317. def to_javascript(self) -> str:
  318. """Convert the destructured argument to JavaScript.
  319. Returns:
  320. The destructured argument in JavaScript.
  321. """
  322. return format.wrap(
  323. ", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
  324. "{",
  325. "}",
  326. )
  327. @dataclasses.dataclass(
  328. frozen=True,
  329. )
  330. class FunctionArgs:
  331. """Class for function arguments."""
  332. args: Tuple[Union[str, DestructuredArg], ...] = tuple()
  333. rest: Optional[str] = None
  334. def format_args_function_operation(
  335. self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
  336. ) -> str:
  337. """Format an args function operation.
  338. Args:
  339. self: The function operation.
  340. args: The function arguments.
  341. return_expr: The return expression.
  342. explicit_return: Whether to use explicit return syntax.
  343. Returns:
  344. The formatted args function operation.
  345. """
  346. arg_names_str = ", ".join(
  347. [
  348. arg if isinstance(arg, str) else arg.to_javascript()
  349. for arg in self._args.args
  350. ]
  351. + ([f"...{self._args.rest}"] if self._args.rest else [])
  352. )
  353. return_expr_str = str(LiteralVar.create(self._return_expr))
  354. # Wrap return expression in curly braces if explicit return syntax is used.
  355. return_expr_str_wrapped = (
  356. format.wrap(return_expr_str, "{", "}")
  357. if self._explicit_return
  358. else return_expr_str
  359. )
  360. return f"(({arg_names_str}) => {return_expr_str_wrapped})"
  361. def pre_check_args(
  362. self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, *args: Var | Any
  363. ) -> Tuple[Callable[[Any], bool], ...]:
  364. """Check if the function can be called with the given arguments.
  365. Args:
  366. self: The function operation.
  367. *args: The arguments to call the function with.
  368. Returns:
  369. True if the function can be called with the given arguments.
  370. """
  371. for i, (validator, arg) in enumerate(zip(self._validators, args)):
  372. if not validator(arg):
  373. arg_name = self._args.args[i] if i < len(self._args.args) else None
  374. if arg_name is not None:
  375. raise VarTypeError(
  376. f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
  377. )
  378. raise VarTypeError(
  379. f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
  380. )
  381. return self._validators[len(args) :]
  382. def figure_partial_type(
  383. self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
  384. *args: Var | Any,
  385. ) -> Tuple[GenericType, Optional[TypeComputer]]:
  386. """Figure out the return type of the function.
  387. Args:
  388. self: The function operation.
  389. *args: The arguments to call the function with.
  390. Returns:
  391. The return type of the function.
  392. """
  393. return (
  394. self._type_computer(*args)
  395. if self._type_computer is not None
  396. else FunctionVar._partial_type(self, *args)
  397. )
  398. @dataclasses.dataclass(
  399. eq=False,
  400. frozen=True,
  401. **{"slots": True} if sys.version_info >= (3, 10) else {},
  402. )
  403. class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]):
  404. """Base class for immutable function defined via arguments and return expression."""
  405. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  406. _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
  407. default_factory=tuple
  408. )
  409. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  410. _function_name: str = dataclasses.field(default="")
  411. _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
  412. _explicit_return: bool = dataclasses.field(default=False)
  413. _cached_var_name = cached_property_no_lock(format_args_function_operation)
  414. _pre_check = pre_check_args # type: ignore
  415. _partial_type = figure_partial_type # type: ignore
  416. @classmethod
  417. def create(
  418. cls,
  419. args_names: Sequence[Union[str, DestructuredArg]],
  420. return_expr: Var | Any,
  421. rest: str | None = None,
  422. validators: Sequence[Callable[[Any], bool]] = (),
  423. function_name: str = "",
  424. explicit_return: bool = False,
  425. type_computer: Optional[TypeComputer] = None,
  426. _var_type: GenericType = Callable,
  427. _var_data: VarData | None = None,
  428. ):
  429. """Create a new function var.
  430. Args:
  431. args_names: The names of the arguments.
  432. return_expr: The return expression of the function.
  433. rest: The name of the rest argument.
  434. validators: The validators for the arguments.
  435. function_name: The name of the function.
  436. explicit_return: Whether to use explicit return syntax.
  437. type_computer: A function to compute the return type.
  438. _var_type: The type of the var.
  439. _var_data: Additional hooks and imports associated with the Var.
  440. Returns:
  441. The function var.
  442. """
  443. return cls(
  444. _js_expr="",
  445. _var_type=_var_type,
  446. _var_data=_var_data,
  447. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  448. _function_name=function_name,
  449. _validators=tuple(validators),
  450. _return_expr=return_expr,
  451. _explicit_return=explicit_return,
  452. _type_computer=type_computer,
  453. )
  454. @dataclasses.dataclass(
  455. eq=False,
  456. frozen=True,
  457. **{"slots": True} if sys.version_info >= (3, 10) else {},
  458. )
  459. class ArgsFunctionOperationBuilder(
  460. CachedVarOperation, BuilderFunctionVar[CALLABLE_TYPE]
  461. ):
  462. """Base class for immutable function defined via arguments and return expression with the builder pattern."""
  463. _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
  464. _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
  465. default_factory=tuple
  466. )
  467. _return_expr: Union[Var, Any] = dataclasses.field(default=None)
  468. _function_name: str = dataclasses.field(default="")
  469. _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
  470. _explicit_return: bool = dataclasses.field(default=False)
  471. _cached_var_name = cached_property_no_lock(format_args_function_operation)
  472. _pre_check = pre_check_args # type: ignore
  473. _partial_type = figure_partial_type # type: ignore
  474. @classmethod
  475. def create(
  476. cls,
  477. args_names: Sequence[Union[str, DestructuredArg]],
  478. return_expr: Var | Any,
  479. rest: str | None = None,
  480. validators: Sequence[Callable[[Any], bool]] = (),
  481. function_name: str = "",
  482. explicit_return: bool = False,
  483. type_computer: Optional[TypeComputer] = None,
  484. _var_type: GenericType = Callable,
  485. _var_data: VarData | None = None,
  486. ):
  487. """Create a new function var.
  488. Args:
  489. args_names: The names of the arguments.
  490. return_expr: The return expression of the function.
  491. rest: The name of the rest argument.
  492. validators: The validators for the arguments.
  493. function_name: The name of the function.
  494. explicit_return: Whether to use explicit return syntax.
  495. type_computer: A function to compute the return type.
  496. _var_type: The type of the var.
  497. _var_data: Additional hooks and imports associated with the Var.
  498. Returns:
  499. The function var.
  500. """
  501. return cls(
  502. _js_expr="",
  503. _var_type=_var_type,
  504. _var_data=_var_data,
  505. _args=FunctionArgs(args=tuple(args_names), rest=rest),
  506. _function_name=function_name,
  507. _validators=tuple(validators),
  508. _return_expr=return_expr,
  509. _explicit_return=explicit_return,
  510. _type_computer=type_computer,
  511. )
  512. if python_version := sys.version_info[:2] >= (3, 10):
  513. JSON_STRINGIFY = FunctionStringVar.create(
  514. "JSON.stringify", _var_type=ReflexCallable[[Any], str]
  515. )
  516. ARRAY_ISARRAY = FunctionStringVar.create(
  517. "Array.isArray", _var_type=ReflexCallable[[Any], bool]
  518. )
  519. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  520. "((__to_string) => __to_string.toString())",
  521. _var_type=ReflexCallable[[Any], str],
  522. )
  523. else:
  524. JSON_STRINGIFY = FunctionStringVar.create(
  525. "JSON.stringify", _var_type=ReflexCallable[Any, str]
  526. )
  527. ARRAY_ISARRAY = FunctionStringVar.create(
  528. "Array.isArray", _var_type=ReflexCallable[Any, bool]
  529. )
  530. PROTOTYPE_TO_STRING = FunctionStringVar.create(
  531. "((__to_string) => __to_string.toString())",
  532. _var_type=ReflexCallable[Any, str],
  533. )