Browse Source

[REF-2879] Make client_state work without global refs (local only) (#3379)

* Make client_state work without global refs (local only)

* client_state: if the default is str, mark _var_is_string=True

Ensure that a string default is not rendered literally

* add `to_int` as a Var operation

* Allow an event handler lambda to return a Var in some cases

If an lambda is passed to an event trigger and it returns a single Var, then
treat it like the Var was directly passed for the event trigger.

This allows ClientState.set_var to be used within a lambda.
Masen Furer 11 months ago
parent
commit
bb44d51f2f
5 changed files with 120 additions and 31 deletions
  1. 12 2
      reflex/components/component.py
  2. 11 5
      reflex/event.py
  3. 81 23
      reflex/experimental/client_state.py
  4. 8 1
      reflex/utils/format.py
  5. 8 0
      reflex/vars.py

+ 12 - 2
reflex/components/component.py

@@ -519,13 +519,23 @@ class Component(BaseComponent, ABC):
                     events.append(event)
                 elif isinstance(v, Callable):
                     # Call the lambda to get the event chain.
-                    events.extend(call_event_fn(v, args_spec))
+                    result = call_event_fn(v, args_spec)
+                    if isinstance(result, Var):
+                        raise ValueError(
+                            f"Invalid event chain: {v}. Cannot use a Var-returning "
+                            "lambda inside an EventChain list."
+                        )
+                    events.extend(result)
                 else:
                     raise ValueError(f"Invalid event: {v}")
 
         # If the input is a callable, create an event chain.
         elif isinstance(value, Callable):
-            events = call_event_fn(value, args_spec)
+            result = call_event_fn(value, args_spec)
+            if isinstance(result, Var):
+                # Recursively call this function if the lambda returned an EventChain Var.
+                return self._create_event_chain(args_spec, result)
+            events = result
 
         # Otherwise, raise an error.
         else:

+ 11 - 5
reflex/event.py

@@ -415,6 +415,8 @@ class FileUpload(Base):
                 )  # type: ignore
             else:
                 raise ValueError(f"{on_upload_progress} is not a valid event handler.")
+            if isinstance(events, Var):
+                raise ValueError(f"{on_upload_progress} cannot return a var {events}.")
             on_upload_progress_chain = EventChain(
                 events=events,
                 args_spec=self.on_upload_progress_args_spec,
@@ -831,19 +833,19 @@ def parse_args_spec(arg_spec: ArgsSpec):
     )
 
 
-def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec]:
+def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec] | Var:
     """Call a function to a list of event specs.
 
-    The function should return either a single EventSpec or a list of EventSpecs.
-    If the function takes in an arg, the arg will be passed to the function.
-    Otherwise, the function will be called with no args.
+    The function should return a single EventSpec, a list of EventSpecs, or a
+    single Var. If the function takes in an arg, the arg will be passed to the
+    function. Otherwise, the function will be called with no args.
 
     Args:
         fn: The function to call.
         arg: The argument to pass to the function.
 
     Returns:
-        The event specs from calling the function.
+        The event specs from calling the function or a Var.
 
     Raises:
         EventHandlerValueError: If the lambda has an invalid signature.
@@ -866,6 +868,10 @@ def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec]:
         else:
             raise EventHandlerValueError(f"Lambda {fn} must have 0 or 1 arguments.")
 
+    # If the function returns a Var, assume it's an EventChain and render it directly.
+    if isinstance(out, Var):
+        return out
+
     # Convert the output to a list.
     if not isinstance(out, List):
         out = [out]

+ 81 - 23
reflex/experimental/client_state.py

@@ -1,4 +1,5 @@
 """Handle client side state with `useState`."""
+from __future__ import annotations
 
 import dataclasses
 import sys
@@ -9,6 +10,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, call_script
 from reflex.utils.imports import ImportVar
 from reflex.vars import Var, VarData
 
+NoValue = object()
+
+
+_refs_import = {
+    f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
+}
+
 
 def _client_state_ref(var_name: str) -> str:
     """Get the ref path for a ClientStateVar.
@@ -36,6 +44,9 @@ class ClientStateVar(Var):
     _setter_name: str = dataclasses.field()
     _getter_name: str = dataclasses.field()
 
+    # Whether to add the var and setter to the global `refs` object for use in any Component.
+    _global_ref: bool = dataclasses.field(default=True)
+
     # The type of the var.
     _var_type: Type = dataclasses.field(default=Any)
 
@@ -62,7 +73,9 @@ class ClientStateVar(Var):
         )
 
     @classmethod
-    def create(cls, var_name, default=None) -> "ClientStateVar":
+    def create(
+        cls, var_name: str, default: Any = NoValue, global_ref: bool = True
+    ) -> "ClientStateVar":
         """Create a local_state Var that can be accessed and updated on the client.
 
         The `ClientStateVar` should be included in the highest parent component
@@ -72,7 +85,7 @@ class ClientStateVar(Var):
 
         To render the var in a component, use the `value` property.
 
-        To update the var in a component, use the `set` property.
+        To update the var in a component, use the `set` property or `set_value` method.
 
         To access the var in an event handler, use the `retrieve` method with
         `callback` set to the event handler which should receive the value.
@@ -83,36 +96,45 @@ class ClientStateVar(Var):
         Args:
             var_name: The name of the variable.
             default: The default value of the variable.
+            global_ref: Whether the state should be accessible in any Component and on the backend.
 
         Returns:
             ClientStateVar
         """
-        if default is None:
+        assert isinstance(var_name, str), "var_name must be a string."
+        if default is NoValue:
             default_var = Var.create_safe("", _var_is_local=False, _var_is_string=False)
         elif not isinstance(default, Var):
-            default_var = Var.create_safe(default)
+            default_var = Var.create_safe(
+                default,
+                _var_is_string=isinstance(default, str),
+            )
         else:
             default_var = default
         setter_name = f"set{var_name.capitalize()}"
+        hooks = {
+            f"const [{var_name}, {setter_name}] = useState({default_var._var_name_unwrapped})": None,
+        }
+        imports = {
+            "react": [ImportVar(tag="useState")],
+        }
+        if global_ref:
+            hooks[f"{_client_state_ref(var_name)} = {var_name}"] = None
+            hooks[f"{_client_state_ref(setter_name)} = {setter_name}"] = None
+            imports.update(_refs_import)
         return cls(
             _var_name="",
             _setter_name=setter_name,
             _getter_name=var_name,
+            _global_ref=global_ref,
             _var_is_local=False,
             _var_is_string=False,
             _var_type=default_var._var_type,
             _var_data=VarData.merge(
                 default_var._var_data,
                 VarData(  # type: ignore
-                    hooks={
-                        f"const [{var_name}, {setter_name}] = useState({default_var._var_name_unwrapped})": None,
-                        f"{_client_state_ref(var_name)} = {var_name}": None,
-                        f"{_client_state_ref(setter_name)} = {setter_name}": None,
-                    },
-                    imports={
-                        "react": [ImportVar(tag="useState", install=False)],
-                        f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
-                    },
+                    hooks=hooks,
+                    imports=imports,
                 ),
             ),
         )
@@ -130,47 +152,73 @@ class ClientStateVar(Var):
         """
         return (
             Var.create_safe(
-                _client_state_ref(self._getter_name),
+                _client_state_ref(self._getter_name)
+                if self._global_ref
+                else self._getter_name,
                 _var_is_local=False,
                 _var_is_string=False,
             )
             .to(self._var_type)
             ._replace(
                 merge_var_data=VarData(  # type: ignore
-                    imports={
-                        f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
-                    }
+                    imports=_refs_import if self._global_ref else {}
                 )
             )
         )
 
-    @property
-    def set(self) -> Var:
+    def set_value(self, value: Any = NoValue) -> Var:
         """Set the value of the client state variable.
 
         This property can only be attached to a frontend event trigger.
 
         To set a value from a backend event handler, see `push`.
 
+        Args:
+            value: The value to set.
+
         Returns:
             A special EventChain Var which will set the value when triggered.
         """
+        setter = (
+            _client_state_ref(self._setter_name)
+            if self._global_ref
+            else self._setter_name
+        )
+        if value is not NoValue:
+            # This is a hack to make it work like an EventSpec taking an arg
+            value = Var.create_safe(value, _var_is_string=isinstance(value, str))
+            if not value._var_is_string and value._var_full_name.startswith("_"):
+                arg = value._var_name_unwrapped
+            else:
+                arg = ""
+            setter = f"({arg}) => {setter}({value._var_name_unwrapped})"
         return (
             Var.create_safe(
-                _client_state_ref(self._setter_name),
+                setter,
                 _var_is_local=False,
                 _var_is_string=False,
             )
             .to(EventChain)
             ._replace(
                 merge_var_data=VarData(  # type: ignore
-                    imports={
-                        f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
-                    }
+                    imports=_refs_import if self._global_ref else {}
                 )
             )
         )
 
+    @property
+    def set(self) -> Var:
+        """Set the value of the client state variable.
+
+        This property can only be attached to a frontend event trigger.
+
+        To set a value from a backend event handler, see `push`.
+
+        Returns:
+            A special EventChain Var which will set the value when triggered.
+        """
+        return self.set_value()
+
     def retrieve(
         self, callback: Union[EventHandler, Callable, None] = None
     ) -> EventSpec:
@@ -183,7 +231,12 @@ class ClientStateVar(Var):
 
         Returns:
             An EventSpec which will retrieve the value when triggered.
+
+        Raises:
+            ValueError: If the ClientStateVar is not global.
         """
+        if not self._global_ref:
+            raise ValueError("ClientStateVar must be global to retrieve the value.")
         return call_script(_client_state_ref(self._getter_name), callback=callback)
 
     def push(self, value: Any) -> EventSpec:
@@ -196,5 +249,10 @@ class ClientStateVar(Var):
 
         Returns:
             An EventSpec which will push the value when triggered.
+
+        Raises:
+            ValueError: If the ClientStateVar is not global.
         """
+        if not self._global_ref:
+            raise ValueError("ClientStateVar must be global to push the value.")
         return call_script(f"{_client_state_ref(self._setter_name)}({value})")

+ 8 - 1
reflex/utils/format.py

@@ -612,6 +612,9 @@ def format_queue_events(
 
     Returns:
         The compiled javascript callback to queue the given events on the frontend.
+
+    Raises:
+        ValueError: If a lambda function is given which returns a Var.
     """
     from reflex.event import (
         EventChain,
@@ -648,7 +651,11 @@ def format_queue_events(
         if isinstance(spec, (EventHandler, EventSpec)):
             specs = [call_event_handler(spec, args_spec or _default_args_spec)]
         elif isinstance(spec, type(lambda: None)):
-            specs = call_event_fn(spec, args_spec or _default_args_spec)
+            specs = call_event_fn(spec, args_spec or _default_args_spec)  # type: ignore
+            if isinstance(specs, Var):
+                raise ValueError(
+                    f"Invalid event spec: {specs}. Expected a list of EventSpecs."
+                )
         payloads.extend(format_event(s) for s in specs)
 
     # Return the final code snippet, expecting queueEvents, processEvent, and socket to be in scope.

+ 8 - 0
reflex/vars.py

@@ -552,6 +552,14 @@ class Var:
         fn = "JSON.stringify" if json else "String"
         return self.operation(fn=fn, type_=str)
 
+    def to_int(self) -> Var:
+        """Convert a var to an int.
+
+        Returns:
+            The parseInt var.
+        """
+        return self.operation(fn="parseInt", type_=int)
+
     def __hash__(self) -> int:
         """Define a hash function for a var.