client_state.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. """Handle client side state with `useState`."""
import dataclasses
import sys
from typing import Any, Callable, Optional, Type, Union

from reflex import constants
from reflex.event import EventChain, EventHandler, EventSpec, call_script
from reflex.utils.imports import ImportVar
from reflex.vars import Var, VarData
  9. def _client_state_ref(var_name: str) -> str:
  10. """Get the ref path for a ClientStateVar.
  11. Args:
  12. var_name: The name of the variable.
  13. Returns:
  14. An accessor for ClientStateVar ref as a string.
  15. """
  16. return f"refs['_client_state_{var_name}']"
  17. @dataclasses.dataclass(
  18. eq=False,
  19. **{"slots": True} if sys.version_info >= (3, 10) else {},
  20. )
  21. class ClientStateVar(Var):
  22. """A Var that exists on the client via useState."""
  23. # The name of the var.
  24. _var_name: str = dataclasses.field()
  25. # Track the names of the getters and setters
  26. _setter_name: str = dataclasses.field()
  27. _getter_name: str = dataclasses.field()
  28. # The type of the var.
  29. _var_type: Type = dataclasses.field(default=Any)
  30. # Whether this is a local javascript variable.
  31. _var_is_local: bool = dataclasses.field(default=False)
  32. # Whether the var is a string literal.
  33. _var_is_string: bool = dataclasses.field(default=False)
  34. # _var_full_name should be prefixed with _var_state
  35. _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False)
  36. # Extra metadata associated with the Var
  37. _var_data: Optional[VarData] = dataclasses.field(default=None)
  38. def __hash__(self) -> int:
  39. """Define a hash function for a var.
  40. Returns:
  41. The hash of the var.
  42. """
  43. return hash(
  44. (self._var_name, str(self._var_type), self._getter_name, self._setter_name)
  45. )
  46. @classmethod
  47. def create(cls, var_name, default=None) -> "ClientStateVar":
  48. """Create a local_state Var that can be accessed and updated on the client.
  49. The `ClientStateVar` should be included in the highest parent component
  50. that contains the components which will access and manipulate the client
  51. state. It has no visual rendering, including it ensures that the
  52. `useState` hook is called in the correct scope.
  53. To render the var in a component, use the `value` property.
  54. To update the var in a component, use the `set` property.
  55. To access the var in an event handler, use the `retrieve` method with
  56. `callback` set to the event handler which should receive the value.
  57. To update the var in an event handler, use the `push` method with the
  58. value to update.
  59. Args:
  60. var_name: The name of the variable.
  61. default: The default value of the variable.
  62. Returns:
  63. ClientStateVar
  64. """
  65. if default is None:
  66. default_var = Var.create_safe("", _var_is_local=False, _var_is_string=False)
  67. elif not isinstance(default, Var):
  68. default_var = Var.create_safe(default)
  69. else:
  70. default_var = default
  71. setter_name = f"set{var_name.capitalize()}"
  72. return cls(
  73. _var_name="",
  74. _setter_name=setter_name,
  75. _getter_name=var_name,
  76. _var_is_local=False,
  77. _var_is_string=False,
  78. _var_type=default_var._var_type,
  79. _var_data=VarData.merge(
  80. default_var._var_data,
  81. VarData( # type: ignore
  82. hooks={
  83. f"const [{var_name}, {setter_name}] = useState({default_var._var_name_unwrapped})": None,
  84. f"{_client_state_ref(var_name)} = {var_name}": None,
  85. f"{_client_state_ref(setter_name)} = {setter_name}": None,
  86. },
  87. imports={
  88. "react": {ImportVar(tag="useState", install=False)},
  89. f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
  90. },
  91. ),
  92. ),
  93. )
  94. @property
  95. def value(self) -> Var:
  96. """Get a placeholder for the Var.
  97. This property can only be rendered on the frontend.
  98. To access the value in a backend event handler, see `retrieve`.
  99. Returns:
  100. an accessor for the client state variable.
  101. """
  102. return (
  103. Var.create_safe(
  104. _client_state_ref(self._getter_name),
  105. _var_is_local=False,
  106. _var_is_string=False,
  107. )
  108. .to(self._var_type)
  109. ._replace(
  110. merge_var_data=VarData( # type: ignore
  111. imports={
  112. f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
  113. }
  114. )
  115. )
  116. )
  117. @property
  118. def set(self) -> Var:
  119. """Set the value of the client state variable.
  120. This property can only be attached to a frontend event trigger.
  121. To set a value from a backend event handler, see `push`.
  122. Returns:
  123. A special EventChain Var which will set the value when triggered.
  124. """
  125. return (
  126. Var.create_safe(
  127. _client_state_ref(self._setter_name),
  128. _var_is_local=False,
  129. _var_is_string=False,
  130. )
  131. .to(EventChain)
  132. ._replace(
  133. merge_var_data=VarData( # type: ignore
  134. imports={
  135. f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
  136. }
  137. )
  138. )
  139. )
  140. def retrieve(self, callback: EventHandler | Callable | None = None) -> EventSpec:
  141. """Pass the value of the client state variable to a backend EventHandler.
  142. The event handler must `yield` or `return` the EventSpec to trigger the event.
  143. Args:
  144. callback: The callback to pass the value to.
  145. Returns:
  146. An EventSpec which will retrieve the value when triggered.
  147. """
  148. return call_script(_client_state_ref(self._getter_name), callback=callback)
  149. def push(self, value: Any) -> EventSpec:
  150. """Push a value to the client state variable from the backend.
  151. The event handler must `yield` or `return` the EventSpec to trigger the event.
  152. Args:
  153. value: The value to update.
  154. Returns:
  155. An EventSpec which will push the value when triggered.
  156. """
  157. return call_script(f"{_client_state_ref(self._setter_name)}({value})")