client_state.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. """Handle client side state with `useState`."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import re
  5. from collections.abc import Callable
  6. from typing import Any
  7. from reflex import constants
  8. from reflex.event import EventChain, EventHandler, EventSpec, run_script
  9. from reflex.utils.imports import ImportVar
  10. from reflex.vars import VarData, get_unique_variable_name
  11. from reflex.vars.base import LiteralVar, Var
  12. from reflex.vars.function import ArgsFunctionOperationBuilder, FunctionVar
  13. NoValue = object()
  14. _refs_import = {
  15. f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
  16. }
  17. def _client_state_ref(var_name: str) -> str:
  18. """Get the ref path for a ClientStateVar.
  19. Args:
  20. var_name: The name of the variable.
  21. Returns:
  22. An accessor for ClientStateVar ref as a string.
  23. """
  24. return f"refs['_client_state_{var_name}']"
  25. def _client_state_ref_dict(var_name: str) -> str:
  26. """Get the ref path for a ClientStateVar.
  27. Args:
  28. var_name: The name of the variable.
  29. Returns:
  30. An accessor for ClientStateVar ref as a string.
  31. """
  32. return f"refs['_client_state_dict_{var_name}']"
  33. @dataclasses.dataclass(
  34. eq=False,
  35. frozen=True,
  36. slots=True,
  37. )
  38. class ClientStateVar(Var):
  39. """A Var that exists on the client via useState."""
  40. # Track the names of the getters and setters
  41. _setter_name: str = dataclasses.field(default="")
  42. _getter_name: str = dataclasses.field(default="")
  43. _id_name: str = dataclasses.field(default="")
  44. # Whether to add the var and setter to the global `refs` object for use in any Component.
  45. _global_ref: bool = dataclasses.field(default=True)
  46. def __hash__(self) -> int:
  47. """Define a hash function for a var.
  48. Returns:
  49. The hash of the var.
  50. """
  51. return hash(
  52. (self._js_expr, str(self._var_type), self._getter_name, self._setter_name)
  53. )
  54. @classmethod
  55. def create(
  56. cls,
  57. var_name: str | None = None,
  58. default: Any = NoValue,
  59. global_ref: bool = True,
  60. ) -> ClientStateVar:
  61. """Create a local_state Var that can be accessed and updated on the client.
  62. The `ClientStateVar` should be included in the highest parent component
  63. that contains the components which will access and manipulate the client
  64. state. It has no visual rendering, including it ensures that the
  65. `useState` hook is called in the correct scope.
  66. To render the var in a component, use the `value` property.
  67. To update the var in a component, use the `set` property or `set_value` method.
  68. To access the var in an event handler, use the `retrieve` method with
  69. `callback` set to the event handler which should receive the value.
  70. To update the var in an event handler, use the `push` method with the
  71. value to update.
  72. Args:
  73. var_name: The name of the variable.
  74. default: The default value of the variable.
  75. global_ref: Whether the state should be accessible in any Component and on the backend.
  76. Raises:
  77. ValueError: If the var_name is not a string.
  78. Returns:
  79. ClientStateVar
  80. """
  81. if var_name is None:
  82. var_name = get_unique_variable_name()
  83. id_name = "id_" + get_unique_variable_name()
  84. if not isinstance(var_name, str):
  85. raise ValueError("var_name must be a string.")
  86. if default is NoValue:
  87. default_var = Var(_js_expr="")
  88. elif not isinstance(default, Var):
  89. default_var = LiteralVar.create(default)
  90. else:
  91. default_var = default
  92. setter_name = f"set{var_name.capitalize()}"
  93. hooks: dict[str, VarData | None] = {
  94. f"const {id_name} = useId()": None,
  95. f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
  96. }
  97. imports = {
  98. "react": [ImportVar(tag="useState"), ImportVar(tag="useId")],
  99. }
  100. if global_ref:
  101. arg_name = get_unique_variable_name()
  102. func = ArgsFunctionOperationBuilder.create(
  103. args_names=(arg_name,),
  104. return_expr=Var("Array.prototype.forEach.call")
  105. .to(FunctionVar)
  106. .call(
  107. (
  108. Var("Object.values")
  109. .to(FunctionVar)
  110. .call(Var(_client_state_ref_dict(setter_name)))
  111. .to(list)
  112. .to(list)
  113. )
  114. + Var.create(
  115. [
  116. Var(
  117. f"(value) => {{ {_client_state_ref(var_name)} = value; }}"
  118. )
  119. ]
  120. ).to(list),
  121. ArgsFunctionOperationBuilder.create(
  122. args_names=("setter",),
  123. return_expr=Var("setter").to(FunctionVar).call(Var(arg_name)),
  124. ),
  125. ),
  126. )
  127. hooks[f"{_client_state_ref(setter_name)} = {func!s}"] = None
  128. hooks[f"{_client_state_ref(var_name)} ??= {var_name!s}"] = None
  129. hooks[f"{_client_state_ref_dict(var_name)} ??= {{}}"] = None
  130. hooks[f"{_client_state_ref_dict(setter_name)} ??= {{}}"] = None
  131. hooks[f"{_client_state_ref_dict(var_name)}[{id_name}] = {var_name}"] = None
  132. hooks[
  133. f"{_client_state_ref_dict(setter_name)}[{id_name}] = {setter_name}"
  134. ] = None
  135. imports.update(_refs_import)
  136. return cls(
  137. _js_expr="",
  138. _setter_name=setter_name,
  139. _getter_name=var_name,
  140. _id_name=id_name,
  141. _global_ref=global_ref,
  142. _var_type=default_var._var_type,
  143. _var_data=VarData.merge(
  144. default_var._var_data,
  145. VarData(
  146. hooks=hooks,
  147. imports=imports,
  148. ),
  149. ),
  150. )
  151. @property
  152. def value(self) -> Var:
  153. """Get a placeholder for the Var.
  154. This property can only be rendered on the frontend.
  155. To access the value in a backend event handler, see `retrieve`.
  156. Returns:
  157. an accessor for the client state variable.
  158. """
  159. return (
  160. Var(
  161. _js_expr=(
  162. _client_state_ref_dict(self._getter_name) + f"[{self._id_name}]"
  163. if self._global_ref
  164. else self._getter_name
  165. ),
  166. _var_data=self._var_data,
  167. )
  168. .to(self._var_type)
  169. ._replace(
  170. merge_var_data=VarData(imports=_refs_import if self._global_ref else {})
  171. )
  172. )
  173. def set_value(self, value: Any = NoValue) -> Var:
  174. """Set the value of the client state variable.
  175. This property can only be attached to a frontend event trigger.
  176. To set a value from a backend event handler, see `push`.
  177. Args:
  178. value: The value to set.
  179. Returns:
  180. A special EventChain Var which will set the value when triggered.
  181. """
  182. _var_data = VarData(imports=_refs_import if self._global_ref else {})
  183. setter = (
  184. Var(_client_state_ref(self._setter_name))
  185. if self._global_ref
  186. else Var(self._setter_name, _var_data=_var_data)
  187. ).to(FunctionVar)
  188. if value is not NoValue:
  189. # This is a hack to make it work like an EventSpec taking an arg
  190. value_var = LiteralVar.create(value)
  191. value_str = str(value_var)
  192. setter = ArgsFunctionOperationBuilder.create(
  193. # remove patterns of ["*"] from the value_str using regex
  194. args_names=(re.sub(r"\[\".*\"\]", "", value_str),)
  195. if value_str.startswith("_")
  196. else (),
  197. return_expr=setter.call(value_var),
  198. )
  199. return setter.to(FunctionVar, EventChain)
  200. @property
  201. def set(self) -> Var:
  202. """Set the value of the client state variable.
  203. This property can only be attached to a frontend event trigger.
  204. To set a value from a backend event handler, see `push`.
  205. Returns:
  206. A special EventChain Var which will set the value when triggered.
  207. """
  208. return self.set_value()
  209. def retrieve(self, callback: EventHandler | Callable | None = None) -> EventSpec:
  210. """Pass the value of the client state variable to a backend EventHandler.
  211. The event handler must `yield` or `return` the EventSpec to trigger the event.
  212. Args:
  213. callback: The callback to pass the value to.
  214. Returns:
  215. An EventSpec which will retrieve the value when triggered.
  216. Raises:
  217. ValueError: If the ClientStateVar is not global.
  218. """
  219. if not self._global_ref:
  220. raise ValueError("ClientStateVar must be global to retrieve the value.")
  221. return run_script(_client_state_ref(self._getter_name), callback=callback)
  222. def push(self, value: Any) -> EventSpec:
  223. """Push a value to the client state variable from the backend.
  224. The event handler must `yield` or `return` the EventSpec to trigger the event.
  225. Args:
  226. value: The value to update.
  227. Returns:
  228. An EventSpec which will push the value when triggered.
  229. Raises:
  230. ValueError: If the ClientStateVar is not global.
  231. """
  232. if not self._global_ref:
  233. raise ValueError("ClientStateVar must be global to push the value.")
  234. value = Var.create(value)
  235. return run_script(f"{_client_state_ref(self._setter_name)}({value})")