client_state.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. """Handle client side state with `useState`."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import re
  5. from typing import Any, Callable
  6. from reflex import constants
  7. from reflex.event import EventChain, EventHandler, EventSpec, run_script
  8. from reflex.utils.imports import ImportVar
  9. from reflex.vars import VarData, get_unique_variable_name
  10. from reflex.vars.base import LiteralVar, Var
  11. from reflex.vars.function import ArgsFunctionOperationBuilder, FunctionVar
  # Sentinel meaning "no argument supplied"; compared by identity so that
  # `None` remains usable as a real value.
  NoValue = object()

  # Import entry for the `refs` name from the state module path. It is attached
  # to the VarData of expressions in this file that reference `refs[...]`.
  _refs_import = {f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")]}
  16. def _client_state_ref(var_name: str) -> str:
  17. """Get the ref path for a ClientStateVar.
  18. Args:
  19. var_name: The name of the variable.
  20. Returns:
  21. An accessor for ClientStateVar ref as a string.
  22. """
  23. return f"refs['_client_state_{var_name}']"
  24. def _client_state_ref_dict(var_name: str) -> str:
  25. """Get the ref path for a ClientStateVar.
  26. Args:
  27. var_name: The name of the variable.
  28. Returns:
  29. An accessor for ClientStateVar ref as a string.
  30. """
  31. return f"refs['_client_state_dict_{var_name}']"
  @dataclasses.dataclass(
      eq=False,
      frozen=True,
      slots=True,
  )
  class ClientStateVar(Var):
      """A Var that exists on the client via useState.

      Instances are built with `create`, which attaches the React hooks and
      imports needed to declare the state on the frontend.
      """

      # Track the names of the getters and setters
      _setter_name: str = dataclasses.field(default="")
      _getter_name: str = dataclasses.field(default="")
      # Name of the JS constant that receives this instance's `useId()` result.
      _id_name: str = dataclasses.field(default="")

      # Whether to add the var and setter to the global `refs` object for use in any Component.
      _global_ref: bool = dataclasses.field(default=True)

      def __hash__(self) -> int:
          """Define a hash function for a var.

          The hash combines the JS expression, the stringified var type and
          the getter/setter names.

          Returns:
              The hash of the var.
          """
          return hash(
              (self._js_expr, str(self._var_type), self._getter_name, self._setter_name)
          )

      @classmethod
      def create(
          cls,
          var_name: str | None = None,
          default: Any = NoValue,
          global_ref: bool = True,
      ) -> "ClientStateVar":
          """Create a local_state Var that can be accessed and updated on the client.

          The `ClientStateVar` should be included in the highest parent component
          that contains the components which will access and manipulate the client
          state. It has no visual rendering, including it ensures that the
          `useState` hook is called in the correct scope.

          To render the var in a component, use the `value` property.

          To update the var in a component, use the `set` property or `set_value` method.

          To access the var in an event handler, use the `retrieve` method with
          `callback` set to the event handler which should receive the value.

          To update the var in an event handler, use the `push` method with the
          value to update.

          Args:
              var_name: The name of the variable. A unique name is generated when omitted.
              default: The default value of the variable. Non-Var values are
                  converted with `LiteralVar.create`; when omitted, `useState()`
                  is emitted without an argument.
              global_ref: Whether the state should be accessible in any Component and on the backend.

          Raises:
              ValueError: If the var_name is not a string.

          Returns:
              ClientStateVar
          """
          if var_name is None:
              var_name = get_unique_variable_name()
          id_name = "id_" + get_unique_variable_name()
          if not isinstance(var_name, str):
              raise ValueError("var_name must be a string.")
          if default is NoValue:
              # No default supplied: the hook below renders as `useState()`.
              default_var = Var(_js_expr="")
          elif not isinstance(default, Var):
              default_var = LiteralVar.create(default)
          else:
              default_var = default
          # NOTE(review): `str.capitalize` also lower-cases everything after the
          # first character, so "myVar" yields "setMyvar" -- confirm this is intended.
          setter_name = f"set{var_name.capitalize()}"
          # Hooks declaring the per-instance id and the React state pair.
          hooks: dict[str, VarData | None] = {
              f"const {id_name} = useId()": None,
              f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
          }
          imports = {
              "react": [ImportVar(tag="useState"), ImportVar(tag="useId")],
          }
          if global_ref:
              arg_name = get_unique_variable_name()
              # Global setter: forwards its argument to every setter registered in
              # the per-instance setter dict, and to one extra function that
              # stores the value in the shared `refs` slot for `var_name`.
              # NOTE(review): the repeated `.to(list)` looks redundant -- confirm
              # before removing.
              func = ArgsFunctionOperationBuilder.create(
                  args_names=(arg_name,),
                  return_expr=Var("Array.prototype.forEach.call")
                  .to(FunctionVar)
                  .call(
                      (
                          Var("Object.values")
                          .to(FunctionVar)
                          .call(Var(_client_state_ref_dict(setter_name)))
                          .to(list)
                          .to(list)
                      )
                      + Var.create(
                          [
                              Var(
                                  f"(value) => {{ {_client_state_ref(var_name)} = value; }}"
                              )
                          ]
                      ).to(list),
                      ArgsFunctionOperationBuilder.create(
                          args_names=("setter",),
                          return_expr=Var("setter").to(FunctionVar).call(Var(arg_name)),
                      ),
                  ),
              )
              # Publish the global setter and seed the shared value slot.
              hooks[f"{_client_state_ref(setter_name)} = {func!s}"] = None
              hooks[f"{_client_state_ref(var_name)} ??= {var_name!s}"] = None
              # Make sure the per-instance dicts exist, then register this
              # instance's value and setter under its id.
              hooks[f"{_client_state_ref_dict(var_name)} ??= {{}}"] = None
              hooks[f"{_client_state_ref_dict(setter_name)} ??= {{}}"] = None
              hooks[f"{_client_state_ref_dict(var_name)}[{id_name}] = {var_name}"] = None
              hooks[
                  f"{_client_state_ref_dict(setter_name)}[{id_name}] = {setter_name}"
              ] = None
              imports.update(_refs_import)
          return cls(
              _js_expr="",
              _setter_name=setter_name,
              _getter_name=var_name,
              _id_name=id_name,
              _global_ref=global_ref,
              _var_type=default_var._var_type,
              _var_data=VarData.merge(
                  default_var._var_data,
                  VarData(
                      hooks=hooks,
                      imports=imports,
                  ),
              ),
          )

      @property
      def value(self) -> Var:
          """Get a placeholder for the Var.

          This property can only be rendered on the frontend.

          To access the value in a backend event handler, see `retrieve`.

          Returns:
              an accessor for the client state variable.
          """
          return (
              Var(
                  _js_expr=(
                      # Global vars are read from this instance's entry in the refs dict.
                      _client_state_ref_dict(self._getter_name) + f"[{self._id_name}]"
                      if self._global_ref
                      else self._getter_name
                  ),
                  _var_data=self._var_data,
              )
              .to(self._var_type)
              ._replace(
                  merge_var_data=VarData(imports=_refs_import if self._global_ref else {})
              )
          )

      def set_value(self, value: Any = NoValue) -> Var:
          """Set the value of the client state variable.

          This property can only be attached to a frontend event trigger.

          To set a value from a backend event handler, see `push`.

          Args:
              value: The value to set. When omitted, the bare setter function is
                  returned.

          Returns:
              A special EventChain Var which will set the value when triggered.
          """
          var_data = VarData(imports=_refs_import if self._global_ref else {})
          setter_expr = (
              _client_state_ref(self._setter_name)
              if self._global_ref
              else self._setter_name
          )
          # Attach the var data in both cases: the global setter expression
          # references `refs`, so it must carry the `refs` import itself.
          setter = Var(setter_expr, _var_data=var_data).to(FunctionVar)

          if value is not NoValue:
              # This is a hack to make it work like an EventSpec taking an arg
              value_var = LiteralVar.create(value)
              value_str = str(value_var)

              # When the rendered value starts with "_" it is used (minus any
              # `["..."]` accessors) as the arrow function's argument name.
              # NOTE(review): presumably such values are event-argument
              # placeholders -- confirm against callers.
              setter = ArgsFunctionOperationBuilder.create(
                  # remove patterns of ["*"] from the value_str using regex
                  args_names=(re.sub(r"\[\".*\"\]", "", value_str),)
                  if value_str.startswith("_")
                  else (),
                  return_expr=setter.call(value_var),
              )

          return setter.to(FunctionVar, EventChain)

      @property
      def set(self) -> Var:
          """Set the value of the client state variable.

          This property can only be attached to a frontend event trigger.

          To set a value from a backend event handler, see `push`.

          Returns:
              A special EventChain Var which will set the value when triggered.
          """
          return self.set_value()

      def retrieve(self, callback: EventHandler | Callable | None = None) -> EventSpec:
          """Pass the value of the client state variable to a backend EventHandler.

          The event handler must `yield` or `return` the EventSpec to trigger the event.

          Args:
              callback: The callback to pass the value to.

          Returns:
              An EventSpec which will retrieve the value when triggered.

          Raises:
              ValueError: If the ClientStateVar is not global.
          """
          if not self._global_ref:
              raise ValueError("ClientStateVar must be global to retrieve the value.")
          # The script is the shared `refs` slot expression for the getter name.
          return run_script(_client_state_ref(self._getter_name), callback=callback)

      def push(self, value: Any) -> EventSpec:
          """Push a value to the client state variable from the backend.

          The event handler must `yield` or `return` the EventSpec to trigger the event.

          Args:
              value: The value to update.

          Returns:
              An EventSpec which will push the value when triggered.

          Raises:
              ValueError: If the ClientStateVar is not global.
          """
          if not self._global_ref:
              raise ValueError("ClientStateVar must be global to push the value.")
          value = Var.create(value)
          # The script calls the global setter stored on `refs` with the rendered value.
          return run_script(f"{_client_state_ref(self._setter_name)}({value})")