client_state.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. """Handle client side state with `useState`."""
  2. from __future__ import annotations
  3. import dataclasses
  4. import sys
  5. from typing import Any, Callable, Union
  6. from reflex import constants
  7. from reflex.event import EventChain, EventHandler, EventSpec, call_script
  8. from reflex.utils.imports import ImportVar
  9. from reflex.vars import (
  10. VarData,
  11. get_unique_variable_name,
  12. )
  13. from reflex.vars.base import LiteralVar, Var
  14. from reflex.vars.function import FunctionVar
# Sentinel meaning "no value was passed". It is distinct from None so that None
# can still be supplied as a real default/value (see `ClientStateVar.create`
# and `ClientStateVar.set_value`, which compare against it with `is`).
NoValue = object()

# Import spec that brings the `refs` object from the module at
# `/{constants.Dirs.STATE_PATH}` into scope; attached to every Var in this file
# whose JS reads or writes `refs[...]`.
_refs_import = {
    f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
}
  19. def _client_state_ref(var_name: str) -> str:
  20. """Get the ref path for a ClientStateVar.
  21. Args:
  22. var_name: The name of the variable.
  23. Returns:
  24. An accessor for ClientStateVar ref as a string.
  25. """
  26. return f"refs['_client_state_{var_name}']"
  27. @dataclasses.dataclass(
  28. eq=False,
  29. frozen=True,
  30. **{"slots": True} if sys.version_info >= (3, 10) else {},
  31. )
  32. class ClientStateVar(Var):
  33. """A Var that exists on the client via useState."""
  34. # Track the names of the getters and setters
  35. _setter_name: str = dataclasses.field(default="")
  36. _getter_name: str = dataclasses.field(default="")
  37. # Whether to add the var and setter to the global `refs` object for use in any Component.
  38. _global_ref: bool = dataclasses.field(default=True)
  39. def __hash__(self) -> int:
  40. """Define a hash function for a var.
  41. Returns:
  42. The hash of the var.
  43. """
  44. return hash(
  45. (self._js_expr, str(self._var_type), self._getter_name, self._setter_name)
  46. )
  47. @classmethod
  48. def create(
  49. cls,
  50. var_name: str | None = None,
  51. default: Any = NoValue,
  52. global_ref: bool = True,
  53. ) -> "ClientStateVar":
  54. """Create a local_state Var that can be accessed and updated on the client.
  55. The `ClientStateVar` should be included in the highest parent component
  56. that contains the components which will access and manipulate the client
  57. state. It has no visual rendering, including it ensures that the
  58. `useState` hook is called in the correct scope.
  59. To render the var in a component, use the `value` property.
  60. To update the var in a component, use the `set` property or `set_value` method.
  61. To access the var in an event handler, use the `retrieve` method with
  62. `callback` set to the event handler which should receive the value.
  63. To update the var in an event handler, use the `push` method with the
  64. value to update.
  65. Args:
  66. var_name: The name of the variable.
  67. default: The default value of the variable.
  68. global_ref: Whether the state should be accessible in any Component and on the backend.
  69. Returns:
  70. ClientStateVar
  71. """
  72. if var_name is None:
  73. var_name = get_unique_variable_name()
  74. assert isinstance(var_name, str), "var_name must be a string."
  75. if default is NoValue:
  76. default_var = Var(_js_expr="")
  77. elif not isinstance(default, Var):
  78. default_var = LiteralVar.create(default)
  79. else:
  80. default_var = default
  81. setter_name = f"set{var_name.capitalize()}"
  82. hooks = {
  83. f"const [{var_name}, {setter_name}] = useState({str(default_var)})": None,
  84. }
  85. imports = {
  86. "react": [ImportVar(tag="useState")],
  87. }
  88. if global_ref:
  89. hooks[f"{_client_state_ref(var_name)} = {var_name}"] = None
  90. hooks[f"{_client_state_ref(setter_name)} = {setter_name}"] = None
  91. imports.update(_refs_import)
  92. return cls(
  93. _js_expr="",
  94. _setter_name=setter_name,
  95. _getter_name=var_name,
  96. _global_ref=global_ref,
  97. _var_type=default_var._var_type,
  98. _var_data=VarData.merge(
  99. default_var._var_data,
  100. VarData(
  101. hooks=hooks,
  102. imports=imports,
  103. ),
  104. ),
  105. )
  106. @property
  107. def value(self) -> Var:
  108. """Get a placeholder for the Var.
  109. This property can only be rendered on the frontend.
  110. To access the value in a backend event handler, see `retrieve`.
  111. Returns:
  112. an accessor for the client state variable.
  113. """
  114. return (
  115. Var(
  116. _js_expr=(
  117. _client_state_ref(self._getter_name)
  118. if self._global_ref
  119. else self._getter_name
  120. )
  121. )
  122. .to(self._var_type)
  123. ._replace(
  124. merge_var_data=VarData( # type: ignore
  125. imports=_refs_import if self._global_ref else {}
  126. )
  127. )
  128. )
  129. def set_value(self, value: Any = NoValue) -> Var:
  130. """Set the value of the client state variable.
  131. This property can only be attached to a frontend event trigger.
  132. To set a value from a backend event handler, see `push`.
  133. Args:
  134. value: The value to set.
  135. Returns:
  136. A special EventChain Var which will set the value when triggered.
  137. """
  138. setter = (
  139. _client_state_ref(self._setter_name)
  140. if self._global_ref
  141. else self._setter_name
  142. )
  143. if value is not NoValue:
  144. import re
  145. # This is a hack to make it work like an EventSpec taking an arg
  146. value_str = str(LiteralVar.create(value))
  147. # remove patterns of ["*"] from the value_str using regex
  148. arg = re.sub(r"\[\".*\"\]", "", value_str)
  149. setter = f"({arg}) => {setter}({str(value)})"
  150. return Var(
  151. _js_expr=setter,
  152. _var_data=VarData(imports=_refs_import if self._global_ref else {}),
  153. ).to(FunctionVar, EventChain)
  154. @property
  155. def set(self) -> Var:
  156. """Set the value of the client state variable.
  157. This property can only be attached to a frontend event trigger.
  158. To set a value from a backend event handler, see `push`.
  159. Returns:
  160. A special EventChain Var which will set the value when triggered.
  161. """
  162. return self.set_value()
  163. def retrieve(
  164. self, callback: Union[EventHandler, Callable, None] = None
  165. ) -> EventSpec:
  166. """Pass the value of the client state variable to a backend EventHandler.
  167. The event handler must `yield` or `return` the EventSpec to trigger the event.
  168. Args:
  169. callback: The callback to pass the value to.
  170. Returns:
  171. An EventSpec which will retrieve the value when triggered.
  172. Raises:
  173. ValueError: If the ClientStateVar is not global.
  174. """
  175. if not self._global_ref:
  176. raise ValueError("ClientStateVar must be global to retrieve the value.")
  177. return call_script(_client_state_ref(self._getter_name), callback=callback)
  178. def push(self, value: Any) -> EventSpec:
  179. """Push a value to the client state variable from the backend.
  180. The event handler must `yield` or `return` the EventSpec to trigger the event.
  181. Args:
  182. value: The value to update.
  183. Returns:
  184. An EventSpec which will push the value when triggered.
  185. Raises:
  186. ValueError: If the ClientStateVar is not global.
  187. """
  188. if not self._global_ref:
  189. raise ValueError("ClientStateVar must be global to push the value.")
  190. return call_script(f"{_client_state_ref(self._setter_name)}({value})")