client_state.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. """Handle client side state with `useState`."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import re
  5. import sys
  6. from typing import Any, Callable, Union
  7. from reflex import constants
  8. from reflex.event import EventChain, EventHandler, EventSpec, run_script
  9. from reflex.utils.imports import ImportVar
  10. from reflex.vars import VarData, get_unique_variable_name
  11. from reflex.vars.base import LiteralVar, Var
  12. from reflex.vars.function import FunctionVar
  13. NoValue = object()
  14. _refs_import = {
  15. f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
  16. }
  17. def _client_state_ref(var_name: str) -> str:
  18. """Get the ref path for a ClientStateVar.
  19. Args:
  20. var_name: The name of the variable.
  21. Returns:
  22. An accessor for ClientStateVar ref as a string.
  23. """
  24. return f"refs['_client_state_{var_name}']"
  25. @dataclasses.dataclass(
  26. eq=False,
  27. frozen=True,
  28. **{"slots": True} if sys.version_info >= (3, 10) else {},
  29. )
  30. class ClientStateVar(Var):
  31. """A Var that exists on the client via useState."""
  32. # Track the names of the getters and setters
  33. _setter_name: str = dataclasses.field(default="")
  34. _getter_name: str = dataclasses.field(default="")
  35. # Whether to add the var and setter to the global `refs` object for use in any Component.
  36. _global_ref: bool = dataclasses.field(default=True)
  37. def __hash__(self) -> int:
  38. """Define a hash function for a var.
  39. Returns:
  40. The hash of the var.
  41. """
  42. return hash(
  43. (self._js_expr, str(self._var_type), self._getter_name, self._setter_name)
  44. )
  45. @classmethod
  46. def create(
  47. cls,
  48. var_name: str | None = None,
  49. default: Any = NoValue,
  50. global_ref: bool = True,
  51. ) -> "ClientStateVar":
  52. """Create a local_state Var that can be accessed and updated on the client.
  53. The `ClientStateVar` should be included in the highest parent component
  54. that contains the components which will access and manipulate the client
  55. state. It has no visual rendering, including it ensures that the
  56. `useState` hook is called in the correct scope.
  57. To render the var in a component, use the `value` property.
  58. To update the var in a component, use the `set` property or `set_value` method.
  59. To access the var in an event handler, use the `retrieve` method with
  60. `callback` set to the event handler which should receive the value.
  61. To update the var in an event handler, use the `push` method with the
  62. value to update.
  63. Args:
  64. var_name: The name of the variable.
  65. default: The default value of the variable.
  66. global_ref: Whether the state should be accessible in any Component and on the backend.
  67. Raises:
  68. ValueError: If the var_name is not a string.
  69. Returns:
  70. ClientStateVar
  71. """
  72. if var_name is None:
  73. var_name = get_unique_variable_name()
  74. if not isinstance(var_name, str):
  75. raise ValueError("var_name must be a string.")
  76. if default is NoValue:
  77. default_var = Var(_js_expr="")
  78. elif not isinstance(default, Var):
  79. default_var = LiteralVar.create(default)
  80. else:
  81. default_var = default
  82. setter_name = f"set{var_name.capitalize()}"
  83. hooks: dict[str, VarData | None] = {
  84. f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
  85. }
  86. imports = {
  87. "react": [ImportVar(tag="useState")],
  88. }
  89. if global_ref:
  90. hooks[f"{_client_state_ref(var_name)} = {var_name}"] = None
  91. hooks[f"{_client_state_ref(setter_name)} = {setter_name}"] = None
  92. imports.update(_refs_import)
  93. return cls(
  94. _js_expr="",
  95. _setter_name=setter_name,
  96. _getter_name=var_name,
  97. _global_ref=global_ref,
  98. _var_type=default_var._var_type,
  99. _var_data=VarData.merge(
  100. default_var._var_data,
  101. VarData(
  102. hooks=hooks,
  103. imports=imports,
  104. ),
  105. ),
  106. )
  107. @property
  108. def value(self) -> Var:
  109. """Get a placeholder for the Var.
  110. This property can only be rendered on the frontend.
  111. To access the value in a backend event handler, see `retrieve`.
  112. Returns:
  113. an accessor for the client state variable.
  114. """
  115. return (
  116. Var(
  117. _js_expr=(
  118. _client_state_ref(self._getter_name)
  119. if self._global_ref
  120. else self._getter_name
  121. )
  122. )
  123. .to(self._var_type)
  124. ._replace(
  125. merge_var_data=VarData( # type: ignore
  126. imports=_refs_import if self._global_ref else {}
  127. )
  128. )
  129. )
  130. def set_value(self, value: Any = NoValue) -> Var:
  131. """Set the value of the client state variable.
  132. This property can only be attached to a frontend event trigger.
  133. To set a value from a backend event handler, see `push`.
  134. Args:
  135. value: The value to set.
  136. Returns:
  137. A special EventChain Var which will set the value when triggered.
  138. """
  139. setter = (
  140. _client_state_ref(self._setter_name)
  141. if self._global_ref
  142. else self._setter_name
  143. )
  144. _var_data = VarData(imports=_refs_import if self._global_ref else {})
  145. if value is not NoValue:
  146. # This is a hack to make it work like an EventSpec taking an arg
  147. value_var = LiteralVar.create(value)
  148. _var_data = VarData.merge(_var_data, value_var._get_all_var_data())
  149. value_str = str(value_var)
  150. if value_str.startswith("_"):
  151. # remove patterns of ["*"] from the value_str using regex
  152. arg = re.sub(r"\[\".*\"\]", "", value_str)
  153. setter = f"(({arg}) => {setter}({value_str}))"
  154. else:
  155. setter = f"(() => {setter}({value_str}))"
  156. return Var(
  157. _js_expr=setter,
  158. _var_data=_var_data,
  159. ).to(FunctionVar, EventChain)
  160. @property
  161. def set(self) -> Var:
  162. """Set the value of the client state variable.
  163. This property can only be attached to a frontend event trigger.
  164. To set a value from a backend event handler, see `push`.
  165. Returns:
  166. A special EventChain Var which will set the value when triggered.
  167. """
  168. return self.set_value()
  169. def retrieve(
  170. self, callback: Union[EventHandler, Callable, None] = None
  171. ) -> EventSpec:
  172. """Pass the value of the client state variable to a backend EventHandler.
  173. The event handler must `yield` or `return` the EventSpec to trigger the event.
  174. Args:
  175. callback: The callback to pass the value to.
  176. Returns:
  177. An EventSpec which will retrieve the value when triggered.
  178. Raises:
  179. ValueError: If the ClientStateVar is not global.
  180. """
  181. if not self._global_ref:
  182. raise ValueError("ClientStateVar must be global to retrieve the value.")
  183. return run_script(_client_state_ref(self._getter_name), callback=callback)
  184. def push(self, value: Any) -> EventSpec:
  185. """Push a value to the client state variable from the backend.
  186. The event handler must `yield` or `return` the EventSpec to trigger the event.
  187. Args:
  188. value: The value to update.
  189. Returns:
  190. An EventSpec which will push the value when triggered.
  191. Raises:
  192. ValueError: If the ClientStateVar is not global.
  193. """
  194. if not self._global_ref:
  195. raise ValueError("ClientStateVar must be global to push the value.")
  196. value = Var.create(value)
  197. return run_script(f"{_client_state_ref(self._setter_name)}({value})")