Преглед на файлове

Abstract color_mode related Var creation (#3533)

* StatefulComponent event handler useCallback dependency array

Ensure the useCallback dependency array for event handlers also handled
destructuring objects and the `...` rest parameter.

Avoid hooks that are not variable declarations.

* Abstract color_mode related Var creation

* Allow `set_color_mode` to take a parameter at compile time
* Update type hinting of `Var._replace` to indicate that it returns BaseVar

* color_mode_button with allow_system=True uses new set_color_mode API

`set_color_mode` is now a CallableVar and uses very similar logic internally,
so this bit of code can be replaced.

* Fix for pydantic v1.10.17
Masen Furer преди 11 месеца
родител
ревизия
958c4fa7f2

+ 4 - 4
reflex/base.py

@@ -6,12 +6,12 @@ import os
 from typing import TYPE_CHECKING, Any, List, Type
 from typing import TYPE_CHECKING, Any, List, Type
 
 
 try:
 try:
-    import pydantic.v1 as pydantic
+    import pydantic.v1.main as pydantic_main
     from pydantic.v1 import BaseModel
     from pydantic.v1 import BaseModel
     from pydantic.v1.fields import ModelField
     from pydantic.v1.fields import ModelField
 except ModuleNotFoundError:
 except ModuleNotFoundError:
     if not TYPE_CHECKING:
     if not TYPE_CHECKING:
-        import pydantic
+        import pydantic.main as pydantic_main
         from pydantic import BaseModel
         from pydantic import BaseModel
         from pydantic.fields import ModelField  # type: ignore
         from pydantic.fields import ModelField  # type: ignore
 
 
@@ -45,10 +45,10 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None
 
 
 # monkeypatch pydantic validate_field_name method to skip validating
 # monkeypatch pydantic validate_field_name method to skip validating
 # shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
 # shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
-pydantic.main.validate_field_name = validate_field_name  # type: ignore
+pydantic_main.validate_field_name = validate_field_name  # type: ignore
 
 
 
 
-class Base(pydantic.BaseModel):  # pyright: ignore [reportUnboundVariable]
+class Base(BaseModel):  # pyright: ignore [reportUnboundVariable]
     """The base class subclassed by all Reflex classes.
     """The base class subclassed by all Reflex classes.
 
 
     This class wraps Pydantic and provides common methods such as
     This class wraps Pydantic and provides common methods such as

+ 14 - 4
reflex/components/component.py

@@ -2121,10 +2121,20 @@ class StatefulComponent(BaseComponent):
         Returns:
         Returns:
             A list of var names created by the hook declaration.
             A list of var names created by the hook declaration.
         """
         """
-        var_name = hook.partition("=")[0].strip().split(None, 1)[1].strip()
-        if var_name.startswith("["):
-            # Break up array destructuring.
-            return [v.strip() for v in var_name.strip("[]").split(",")]
+        # Ensure that the hook is a var declaration.
+        var_decl = hook.partition("=")[0].strip()
+        if not any(var_decl.startswith(kw) for kw in ["const ", "let ", "var "]):
+            return []
+
+        # Extract the var name from the declaration.
+        _, _, var_name = var_decl.partition(" ")
+        var_name = var_name.strip()
+
+        # Break up array and object destructuring if used.
+        if var_name.startswith("[") or var_name.startswith("{"):
+            return [
+                v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",")
+            ]
         return [var_name]
         return [var_name]
 
 
     @classmethod
     @classmethod

+ 2 - 2
reflex/components/core/cond.py

@@ -12,7 +12,7 @@ from reflex.constants.colors import Color
 from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
 from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
 from reflex.utils import format
 from reflex.utils import format
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.imports import ImportDict, ImportVar
-from reflex.vars import Var, VarData
+from reflex.vars import BaseVar, Var, VarData
 
 
 _IS_TRUE_IMPORT: ImportDict = {
 _IS_TRUE_IMPORT: ImportDict = {
     f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
     f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
@@ -118,7 +118,7 @@ def cond(condition: Any, c1: Component) -> Component: ...
 
 
 
 
 @overload
 @overload
-def cond(condition: Any, c1: Any, c2: Any) -> Var: ...
+def cond(condition: Any, c1: Any, c2: Any) -> BaseVar: ...
 
 
 
 
 def cond(condition: Any, c1: Any, c2: Any = None):
 def cond(condition: Any, c1: Any, c2: Any = None):

+ 2 - 9
reflex/components/radix/themes/color_mode.py

@@ -25,7 +25,6 @@ from reflex.components.core.cond import Cond, color_mode_cond, cond
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
 from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu
 from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu
 from reflex.components.radix.themes.components.switch import Switch
 from reflex.components.radix.themes.components.switch import Switch
-from reflex.event import EventChain
 from reflex.style import LIGHT_COLOR_MODE, color_mode, set_color_mode, toggle_color_mode
 from reflex.style import LIGHT_COLOR_MODE, color_mode, set_color_mode, toggle_color_mode
 from reflex.utils import console
 from reflex.utils import console
 from reflex.vars import BaseVar, Var
 from reflex.vars import BaseVar, Var
@@ -144,15 +143,9 @@ class ColorModeIconButton(IconButton):
         if allow_system:
         if allow_system:
 
 
             def color_mode_item(_color_mode):
             def color_mode_item(_color_mode):
-                setter = Var.create_safe(
-                    f'() => {set_color_mode._var_name}("{_color_mode}")',
-                    _var_is_string=False,
-                    _var_is_local=True,
-                    _var_data=set_color_mode._var_data,
+                return dropdown_menu.item(
+                    _color_mode.title(), on_click=set_color_mode(_color_mode)
                 )
                 )
-                setter._var_type = EventChain
-
-                return dropdown_menu.item(_color_mode.title(), on_click=setter)  # type: ignore
 
 
             return dropdown_menu.root(
             return dropdown_menu.root(
                 dropdown_menu.trigger(
                 dropdown_menu.trigger(

+ 0 - 1
reflex/components/radix/themes/color_mode.pyi

@@ -14,7 +14,6 @@ from reflex.components.core.cond import Cond, color_mode_cond, cond
 from reflex.components.lucide.icon import Icon
 from reflex.components.lucide.icon import Icon
 from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu
 from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu
 from reflex.components.radix.themes.components.switch import Switch
 from reflex.components.radix.themes.components.switch import Switch
-from reflex.event import EventChain
 from reflex.style import LIGHT_COLOR_MODE, color_mode, set_color_mode, toggle_color_mode
 from reflex.style import LIGHT_COLOR_MODE, color_mode, set_color_mode, toggle_color_mode
 from reflex.utils import console
 from reflex.utils import console
 from reflex.vars import BaseVar, Var
 from reflex.vars import BaseVar, Var

+ 60 - 26
reflex/style.py

@@ -2,54 +2,88 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Any, Tuple
+from typing import Any, Literal, Tuple, Type
 
 
 from reflex import constants
 from reflex import constants
 from reflex.event import EventChain
 from reflex.event import EventChain
 from reflex.utils import format
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
-from reflex.vars import BaseVar, Var, VarData
+from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 
 VarData.update_forward_refs()  # Ensure all type definitions are resolved
 VarData.update_forward_refs()  # Ensure all type definitions are resolved
 
 
 SYSTEM_COLOR_MODE: str = "system"
 SYSTEM_COLOR_MODE: str = "system"
 LIGHT_COLOR_MODE: str = "light"
 LIGHT_COLOR_MODE: str = "light"
 DARK_COLOR_MODE: str = "dark"
 DARK_COLOR_MODE: str = "dark"
+LiteralColorMode = Literal["system", "light", "dark"]
 
 
 # Reference the global ColorModeContext
 # Reference the global ColorModeContext
 color_mode_imports = {
 color_mode_imports = {
     f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
     f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
     "react": [ImportVar(tag="useContext")],
     "react": [ImportVar(tag="useContext")],
 }
 }
-color_mode_toggle_hooks = {
-    f"const {{ {constants.ColorMode.RESOLVED_NAME}, {constants.ColorMode.TOGGLE} }} = useContext(ColorModeContext)": None,
-}
-color_mode_set_hooks = {
-    f"const {{ {constants.ColorMode.NAME}, {constants.ColorMode.RESOLVED_NAME}, {constants.ColorMode.TOGGLE}, {constants.ColorMode.SET} }} = useContext(ColorModeContext)": None,
-}
-color_mode_var_data = VarData(imports=color_mode_imports, hooks=color_mode_toggle_hooks)
+
+
+def _color_mode_var(_var_name: str, _var_type: Type = str) -> BaseVar:
+    """Create a Var that destructs the _var_name from ColorModeContext.
+
+    Args:
+        _var_name: The name of the variable to get from ColorModeContext.
+        _var_type: The type of the Var.
+
+    Returns:
+        The BaseVar for accessing _var_name from ColorModeContext.
+    """
+    return BaseVar(
+        _var_name=_var_name,
+        _var_type=_var_type,
+        _var_is_local=False,
+        _var_is_string=False,
+        _var_data=VarData(
+            imports=color_mode_imports,
+            hooks={f"const {{ {_var_name} }} = useContext(ColorModeContext)": None},
+        ),
+    )
+
+
+@CallableVar
+def set_color_mode(
+    new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None,
+) -> BaseVar[EventChain]:
+    """Create an EventChain Var that sets the color mode to a specific value.
+
+    Note: `set_color_mode` is not a real event and cannot be triggered from a
+    backend event handler.
+
+    Args:
+        new_color_mode: The color mode to set.
+
+    Returns:
+        The EventChain Var that can be passed to an event trigger.
+    """
+    base_setter = _color_mode_var(
+        _var_name=constants.ColorMode.SET,
+        _var_type=EventChain,
+    )
+    if new_color_mode is None:
+        return base_setter
+
+    if not isinstance(new_color_mode, Var):
+        new_color_mode = Var.create_safe(new_color_mode, _var_is_string=True)
+    return base_setter._replace(
+        _var_name=f"() => {base_setter._var_name}({new_color_mode._var_name_unwrapped})",
+        merge_var_data=new_color_mode._var_data,
+    )
+
+
 # Var resolves to the current color mode for the app ("light", "dark" or "system")
 # Var resolves to the current color mode for the app ("light", "dark" or "system")
-color_mode = BaseVar(
-    _var_name=constants.ColorMode.NAME,
-    _var_type="str",
-    _var_data=color_mode_var_data,
-)
+color_mode = _color_mode_var(_var_name=constants.ColorMode.NAME)
 # Var resolves to the resolved color mode for the app ("light" or "dark")
 # Var resolves to the resolved color mode for the app ("light" or "dark")
-resolved_color_mode = BaseVar(
-    _var_name=constants.ColorMode.RESOLVED_NAME,
-    _var_type="str",
-    _var_data=color_mode_var_data,
-)
+resolved_color_mode = _color_mode_var(_var_name=constants.ColorMode.RESOLVED_NAME)
 # Var resolves to a function invocation that toggles the color mode
 # Var resolves to a function invocation that toggles the color mode
-toggle_color_mode = BaseVar(
+toggle_color_mode = _color_mode_var(
     _var_name=constants.ColorMode.TOGGLE,
     _var_name=constants.ColorMode.TOGGLE,
     _var_type=EventChain,
     _var_type=EventChain,
-    _var_data=color_mode_var_data,
-)
-set_color_mode = BaseVar(
-    _var_name=constants.ColorMode.SET,
-    _var_type=EventChain,
-    _var_data=VarData(imports=color_mode_imports, hooks=color_mode_set_hooks),
 )
 )
 
 
 breakpoints = ["0", "30em", "48em", "62em", "80em", "96em"]
 breakpoints = ["0", "30em", "48em", "62em", "80em", "96em"]

+ 1 - 1
reflex/vars.py

@@ -482,7 +482,7 @@ class Var:
             self._var_name = _var_name
             self._var_name = _var_name
             self._var_data = VarData.merge(self._var_data, _var_data)
             self._var_data = VarData.merge(self._var_data, _var_data)
 
 
-    def _replace(self, merge_var_data=None, **kwargs: Any) -> Var:
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> BaseVar:
         """Make a copy of this Var with updated fields.
         """Make a copy of this Var with updated fields.
 
 
         Args:
         Args:

+ 1 - 1
reflex/vars.pyi

@@ -59,7 +59,7 @@ class Var:
     ) -> Var: ...
     ) -> Var: ...
     @classmethod
     @classmethod
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
-    def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ...
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> BaseVar: ...
     def equals(self, other: Var) -> bool: ...
     def equals(self, other: Var) -> bool: ...
     def to_string(self) -> Var: ...
     def to_string(self) -> Var: ...
     def __hash__(self) -> int: ...
     def __hash__(self) -> int: ...