Browse Source

Abstract color_mode related Var creation (#3533)

* StatefulComponent event handler useCallback dependency array

Ensure the useCallback dependency array for event handlers also handled
destructuring objects and the `...` rest parameter.

Avoid hooks that are not variable declarations.

* Abstract color_mode related Var creation

* Allow `set_color_mode` to take a parameter at compile time
* Update type hinting of `Var._replace` to indicate that it returns BaseVar

* color_mode_button with allow_system=True uses new set_color_mode API

`set_color_mode` is now a CallableVar and uses very similar logic internally,
so this bit of code can be replaced.

* Fix for pydantic v1.10.17
Masen Furer 11 months ago
parent
commit
958c4fa7f2

+ 4 - 4
reflex/base.py

@@ -6,12 +6,12 @@ import os
 from typing import TYPE_CHECKING, Any, List, Type
 
 try:
-    import pydantic.v1 as pydantic
+    import pydantic.v1.main as pydantic_main
     from pydantic.v1 import BaseModel
     from pydantic.v1.fields import ModelField
 except ModuleNotFoundError:
     if not TYPE_CHECKING:
-        import pydantic
+        import pydantic.main as pydantic_main
         from pydantic import BaseModel
         from pydantic.fields import ModelField  # type: ignore
 
@@ -45,10 +45,10 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None
 
 # monkeypatch pydantic validate_field_name method to skip validating
 # shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
-pydantic.main.validate_field_name = validate_field_name  # type: ignore
+pydantic_main.validate_field_name = validate_field_name  # type: ignore
 
 
-class Base(pydantic.BaseModel):  # pyright: ignore [reportUnboundVariable]
+class Base(BaseModel):  # pyright: ignore [reportUnboundVariable]
     """The base class subclassed by all Reflex classes.
 
     This class wraps Pydantic and provides common methods such as

+ 14 - 4
reflex/components/component.py

@@ -2121,10 +2121,20 @@ class StatefulComponent(BaseComponent):
         Returns:
             A list of var names created by the hook declaration.
         """
-        var_name = hook.partition("=")[0].strip().split(None, 1)[1].strip()
-        if var_name.startswith("["):
-            # Break up array destructuring.
-            return [v.strip() for v in var_name.strip("[]").split(",")]
+        # Ensure that the hook is a var declaration.
+        var_decl = hook.partition("=")[0].strip()
+        if not any(var_decl.startswith(kw) for kw in ["const ", "let ", "var "]):
+            return []
+
+        # Extract the var name from the declaration.
+        _, _, var_name = var_decl.partition(" ")
+        var_name = var_name.strip()
+
+        # Break up array and object destructuring if used.
+        if var_name.startswith("[") or var_name.startswith("{"):
+            return [
+                v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",")
+            ]
         return [var_name]
 
     @classmethod

+ 2 - 2
reflex/components/core/cond.py

@@ -12,7 +12,7 @@ from reflex.constants.colors import Color
 from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
 from reflex.utils import format
 from reflex.utils.imports import ImportDict, ImportVar
-from reflex.vars import Var, VarData
+from reflex.vars import BaseVar, Var, VarData
 
 _IS_TRUE_IMPORT: ImportDict = {
     f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
@@ -118,7 +118,7 @@ def cond(condition: Any, c1: Component) -> Component: ...
 
 
 @overload
-def cond(condition: Any, c1: Any, c2: Any) -> Var: ...
+def cond(condition: Any, c1: Any, c2: Any) -> BaseVar: ...
 
 
 def cond(condition: Any, c1: Any, c2: Any = None):

+ 2 - 9
reflex/components/radix/themes/color_mode.py

@@ -25,7 +25,6 @@ from reflex.components.core.cond import Cond, color_mode_cond, cond
 from reflex.components.lucide.icon import Icon
 from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu
 from reflex.components.radix.themes.components.switch import Switch
-from reflex.event import EventChain
 from reflex.style import LIGHT_COLOR_MODE, color_mode, set_color_mode, toggle_color_mode
 from reflex.utils import console
 from reflex.vars import BaseVar, Var
@@ -144,15 +143,9 @@ class ColorModeIconButton(IconButton):
         if allow_system:
 
             def color_mode_item(_color_mode):
-                setter = Var.create_safe(
-                    f'() => {set_color_mode._var_name}("{_color_mode}")',
-                    _var_is_string=False,
-                    _var_is_local=True,
-                    _var_data=set_color_mode._var_data,
+                return dropdown_menu.item(
+                    _color_mode.title(), on_click=set_color_mode(_color_mode)
                 )
-                setter._var_type = EventChain
-
-                return dropdown_menu.item(_color_mode.title(), on_click=setter)  # type: ignore
 
             return dropdown_menu.root(
                 dropdown_menu.trigger(

+ 0 - 1
reflex/components/radix/themes/color_mode.pyi

@@ -14,7 +14,6 @@ from reflex.components.core.cond import Cond, color_mode_cond, cond
 from reflex.components.lucide.icon import Icon
 from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu
 from reflex.components.radix.themes.components.switch import Switch
-from reflex.event import EventChain
 from reflex.style import LIGHT_COLOR_MODE, color_mode, set_color_mode, toggle_color_mode
 from reflex.utils import console
 from reflex.vars import BaseVar, Var

+ 60 - 26
reflex/style.py

@@ -2,54 +2,88 @@
 
 from __future__ import annotations
 
-from typing import Any, Tuple
+from typing import Any, Literal, Tuple, Type
 
 from reflex import constants
 from reflex.event import EventChain
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
-from reflex.vars import BaseVar, Var, VarData
+from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 VarData.update_forward_refs()  # Ensure all type definitions are resolved
 
 SYSTEM_COLOR_MODE: str = "system"
 LIGHT_COLOR_MODE: str = "light"
 DARK_COLOR_MODE: str = "dark"
+LiteralColorMode = Literal["system", "light", "dark"]
 
 # Reference the global ColorModeContext
 color_mode_imports = {
     f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
     "react": [ImportVar(tag="useContext")],
 }
-color_mode_toggle_hooks = {
-    f"const {{ {constants.ColorMode.RESOLVED_NAME}, {constants.ColorMode.TOGGLE} }} = useContext(ColorModeContext)": None,
-}
-color_mode_set_hooks = {
-    f"const {{ {constants.ColorMode.NAME}, {constants.ColorMode.RESOLVED_NAME}, {constants.ColorMode.TOGGLE}, {constants.ColorMode.SET} }} = useContext(ColorModeContext)": None,
-}
-color_mode_var_data = VarData(imports=color_mode_imports, hooks=color_mode_toggle_hooks)
+
+
+def _color_mode_var(_var_name: str, _var_type: Type = str) -> BaseVar:
+    """Create a Var that destructs the _var_name from ColorModeContext.
+
+    Args:
+        _var_name: The name of the variable to get from ColorModeContext.
+        _var_type: The type of the Var.
+
+    Returns:
+        The BaseVar for accessing _var_name from ColorModeContext.
+    """
+    return BaseVar(
+        _var_name=_var_name,
+        _var_type=_var_type,
+        _var_is_local=False,
+        _var_is_string=False,
+        _var_data=VarData(
+            imports=color_mode_imports,
+            hooks={f"const {{ {_var_name} }} = useContext(ColorModeContext)": None},
+        ),
+    )
+
+
+@CallableVar
+def set_color_mode(
+    new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None,
+) -> BaseVar[EventChain]:
+    """Create an EventChain Var that sets the color mode to a specific value.
+
+    Note: `set_color_mode` is not a real event and cannot be triggered from a
+    backend event handler.
+
+    Args:
+        new_color_mode: The color mode to set.
+
+    Returns:
+        The EventChain Var that can be passed to an event trigger.
+    """
+    base_setter = _color_mode_var(
+        _var_name=constants.ColorMode.SET,
+        _var_type=EventChain,
+    )
+    if new_color_mode is None:
+        return base_setter
+
+    if not isinstance(new_color_mode, Var):
+        new_color_mode = Var.create_safe(new_color_mode, _var_is_string=True)
+    return base_setter._replace(
+        _var_name=f"() => {base_setter._var_name}({new_color_mode._var_name_unwrapped})",
+        merge_var_data=new_color_mode._var_data,
+    )
+
+
 # Var resolves to the current color mode for the app ("light", "dark" or "system")
-color_mode = BaseVar(
-    _var_name=constants.ColorMode.NAME,
-    _var_type="str",
-    _var_data=color_mode_var_data,
-)
+color_mode = _color_mode_var(_var_name=constants.ColorMode.NAME)
 # Var resolves to the resolved color mode for the app ("light" or "dark")
-resolved_color_mode = BaseVar(
-    _var_name=constants.ColorMode.RESOLVED_NAME,
-    _var_type="str",
-    _var_data=color_mode_var_data,
-)
+resolved_color_mode = _color_mode_var(_var_name=constants.ColorMode.RESOLVED_NAME)
 # Var resolves to a function invocation that toggles the color mode
-toggle_color_mode = BaseVar(
+toggle_color_mode = _color_mode_var(
     _var_name=constants.ColorMode.TOGGLE,
     _var_type=EventChain,
-    _var_data=color_mode_var_data,
-)
-set_color_mode = BaseVar(
-    _var_name=constants.ColorMode.SET,
-    _var_type=EventChain,
-    _var_data=VarData(imports=color_mode_imports, hooks=color_mode_set_hooks),
 )
 
 breakpoints = ["0", "30em", "48em", "62em", "80em", "96em"]

+ 1 - 1
reflex/vars.py

@@ -482,7 +482,7 @@ class Var:
             self._var_name = _var_name
             self._var_data = VarData.merge(self._var_data, _var_data)
 
-    def _replace(self, merge_var_data=None, **kwargs: Any) -> Var:
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> BaseVar:
         """Make a copy of this Var with updated fields.
 
         Args:

+ 1 - 1
reflex/vars.pyi

@@ -59,7 +59,7 @@ class Var:
     ) -> Var: ...
     @classmethod
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
-    def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ...
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> BaseVar: ...
     def equals(self, other: Var) -> bool: ...
     def to_string(self) -> Var: ...
     def __hash__(self) -> int: ...