client_state.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. """Handle client side state with `useState`."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import re
  5. import sys
  6. from typing import Any, Callable, Union
  7. from reflex import constants
  8. from reflex.event import EventChain, EventHandler, EventSpec, run_script
  9. from reflex.utils.imports import ImportVar
  10. from reflex.vars import VarData, get_unique_variable_name
  11. from reflex.vars.base import LiteralVar, Var
  12. from reflex.vars.function import ArgsFunctionOperationBuilder, FunctionVar
  13. NoValue = object()
  14. _refs_import = {
  15. f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
  16. }
  17. def _client_state_ref(var_name: str) -> str:
  18. """Get the ref path for a ClientStateVar.
  19. Args:
  20. var_name: The name of the variable.
  21. Returns:
  22. An accessor for ClientStateVar ref as a string.
  23. """
  24. return f"refs['_client_state_{var_name}']"
  25. @dataclasses.dataclass(
  26. eq=False,
  27. frozen=True,
  28. **{"slots": True} if sys.version_info >= (3, 10) else {},
  29. )
  30. class ClientStateVar(Var):
  31. """A Var that exists on the client via useState."""
  32. # Track the names of the getters and setters
  33. _setter_name: str = dataclasses.field(default="")
  34. _getter_name: str = dataclasses.field(default="")
  35. _id_name: str = dataclasses.field(default="")
  36. # Whether to add the var and setter to the global `refs` object for use in any Component.
  37. _global_ref: bool = dataclasses.field(default=True)
  38. def __hash__(self) -> int:
  39. """Define a hash function for a var.
  40. Returns:
  41. The hash of the var.
  42. """
  43. return hash(
  44. (self._js_expr, str(self._var_type), self._getter_name, self._setter_name)
  45. )
  46. @classmethod
  47. def create(
  48. cls,
  49. var_name: str | None = None,
  50. default: Any = NoValue,
  51. global_ref: bool = True,
  52. ) -> "ClientStateVar":
  53. """Create a local_state Var that can be accessed and updated on the client.
  54. The `ClientStateVar` should be included in the highest parent component
  55. that contains the components which will access and manipulate the client
  56. state. It has no visual rendering, including it ensures that the
  57. `useState` hook is called in the correct scope.
  58. To render the var in a component, use the `value` property.
  59. To update the var in a component, use the `set` property or `set_value` method.
  60. To access the var in an event handler, use the `retrieve` method with
  61. `callback` set to the event handler which should receive the value.
  62. To update the var in an event handler, use the `push` method with the
  63. value to update.
  64. Args:
  65. var_name: The name of the variable.
  66. default: The default value of the variable.
  67. global_ref: Whether the state should be accessible in any Component and on the backend.
  68. Raises:
  69. ValueError: If the var_name is not a string.
  70. Returns:
  71. ClientStateVar
  72. """
  73. if var_name is None:
  74. var_name = get_unique_variable_name()
  75. id_name = "id_" + get_unique_variable_name()
  76. if not isinstance(var_name, str):
  77. raise ValueError("var_name must be a string.")
  78. if default is NoValue:
  79. default_var = Var(_js_expr="")
  80. elif not isinstance(default, Var):
  81. default_var = LiteralVar.create(default)
  82. else:
  83. default_var = default
  84. setter_name = f"set{var_name.capitalize()}"
  85. hooks: dict[str, VarData | None] = {
  86. f"const {id_name} = useId()": None,
  87. f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
  88. }
  89. imports = {
  90. "react": [ImportVar(tag="useState"), ImportVar(tag="useId")],
  91. }
  92. if global_ref:
  93. hooks[f"{_client_state_ref(var_name)} ??= {{}}"] = None
  94. hooks[f"{_client_state_ref(setter_name)} ??= {{}}"] = None
  95. hooks[f"{_client_state_ref(var_name)}[{id_name}] = {var_name}"] = None
  96. hooks[f"{_client_state_ref(setter_name)}[{id_name}] = {setter_name}"] = None
  97. imports.update(_refs_import)
  98. return cls(
  99. _js_expr="",
  100. _setter_name=setter_name,
  101. _getter_name=var_name,
  102. _id_name=id_name,
  103. _global_ref=global_ref,
  104. _var_type=default_var._var_type,
  105. _var_data=VarData.merge(
  106. default_var._var_data,
  107. VarData(
  108. hooks=hooks,
  109. imports=imports,
  110. ),
  111. ),
  112. )
  113. @property
  114. def value(self) -> Var:
  115. """Get a placeholder for the Var.
  116. This property can only be rendered on the frontend.
  117. To access the value in a backend event handler, see `retrieve`.
  118. Returns:
  119. an accessor for the client state variable.
  120. """
  121. return (
  122. Var(
  123. _js_expr=(
  124. _client_state_ref(self._getter_name) + f"[{self._id_name}]"
  125. if self._global_ref
  126. else self._getter_name
  127. ),
  128. _var_data=self._var_data,
  129. )
  130. .to(self._var_type)
  131. ._replace(
  132. merge_var_data=VarData( # type: ignore
  133. imports=_refs_import if self._global_ref else {}
  134. )
  135. )
  136. )
  137. def set_value(self, value: Any = NoValue) -> Var:
  138. """Set the value of the client state variable.
  139. This property can only be attached to a frontend event trigger.
  140. To set a value from a backend event handler, see `push`.
  141. Args:
  142. value: The value to set.
  143. Returns:
  144. A special EventChain Var which will set the value when triggered.
  145. """
  146. _var_data = VarData(imports=_refs_import if self._global_ref else {})
  147. arg_name = get_unique_variable_name()
  148. setter = (
  149. ArgsFunctionOperationBuilder.create(
  150. args_names=(arg_name,),
  151. return_expr=Var("Array.prototype.forEach.call")
  152. .to(FunctionVar)
  153. .call(
  154. Var("Object.values")
  155. .to(FunctionVar)
  156. .call(Var(_client_state_ref(self._setter_name))),
  157. ArgsFunctionOperationBuilder.create(
  158. args_names=("setter",),
  159. return_expr=Var("setter").to(FunctionVar).call(Var(arg_name)),
  160. ),
  161. ),
  162. _var_data=_var_data,
  163. )
  164. if self._global_ref
  165. else Var(self._setter_name, _var_data=_var_data).to(FunctionVar)
  166. )
  167. if value is not NoValue:
  168. # This is a hack to make it work like an EventSpec taking an arg
  169. value_var = LiteralVar.create(value)
  170. value_str = str(value_var)
  171. setter = ArgsFunctionOperationBuilder.create(
  172. # remove patterns of ["*"] from the value_str using regex
  173. args_names=(re.sub(r"\[\".*\"\]", "", value_str),)
  174. if value_str.startswith("_")
  175. else (),
  176. return_expr=setter.call(value_var),
  177. )
  178. return setter.to(FunctionVar, EventChain)
  179. @property
  180. def set(self) -> Var:
  181. """Set the value of the client state variable.
  182. This property can only be attached to a frontend event trigger.
  183. To set a value from a backend event handler, see `push`.
  184. Returns:
  185. A special EventChain Var which will set the value when triggered.
  186. """
  187. return self.set_value()
  188. def retrieve(
  189. self, callback: Union[EventHandler, Callable, None] = None
  190. ) -> EventSpec:
  191. """Pass the value of the client state variable to a backend EventHandler.
  192. The event handler must `yield` or `return` the EventSpec to trigger the event.
  193. Args:
  194. callback: The callback to pass the value to.
  195. Returns:
  196. An EventSpec which will retrieve the value when triggered.
  197. Raises:
  198. ValueError: If the ClientStateVar is not global.
  199. """
  200. if not self._global_ref:
  201. raise ValueError("ClientStateVar must be global to retrieve the value.")
  202. return run_script(_client_state_ref(self._getter_name), callback=callback)
  203. def push(self, value: Any) -> EventSpec:
  204. """Push a value to the client state variable from the backend.
  205. The event handler must `yield` or `return` the EventSpec to trigger the event.
  206. Args:
  207. value: The value to update.
  208. Returns:
  209. An EventSpec which will push the value when triggered.
  210. Raises:
  211. ValueError: If the ClientStateVar is not global.
  212. """
  213. if not self._global_ref:
  214. raise ValueError("ClientStateVar must be global to push the value.")
  215. value = Var.create(value)
  216. return run_script(f"{_client_state_ref(self._setter_name)}({value})")