Selaa lähdekoodia

introduce customizable breakpoints (#3568)

* introduce customizable breakpoints
Khaleel Al-Adhami 10 kuukautta sitten
vanhempi
säilyke
82a0d7f3fb

+ 1 - 0
reflex/__init__.py

@@ -214,6 +214,7 @@ COMPONENTS_CORE_MAPPING: dict = {
     "components.core.match": ["match"],
     "components.core.match": ["match"],
     "components.core.clipboard": ["clipboard"],
     "components.core.clipboard": ["clipboard"],
     "components.core.colors": ["color"],
     "components.core.colors": ["color"],
+    "components.core.breakpoints": ["breakpoints"],
     "components.core.responsive": [
     "components.core.responsive": [
         "desktop_only",
         "desktop_only",
         "mobile_and_tablet",
         "mobile_and_tablet",

+ 1 - 0
reflex/__init__.pyi

@@ -131,6 +131,7 @@ from .components.core.html import html as html
 from .components.core.match import match as match
 from .components.core.match import match as match
 from .components.core.clipboard import clipboard as clipboard
 from .components.core.clipboard import clipboard as clipboard
 from .components.core.colors import color as color
 from .components.core.colors import color as color
+from .components.core.breakpoints import breakpoints as breakpoints
 from .components.core.responsive import desktop_only as desktop_only
 from .components.core.responsive import desktop_only as desktop_only
 from .components.core.responsive import mobile_and_tablet as mobile_and_tablet
 from .components.core.responsive import mobile_and_tablet as mobile_and_tablet
 from .components.core.responsive import mobile_only as mobile_only
 from .components.core.responsive import mobile_only as mobile_only

+ 4 - 0
reflex/app.py

@@ -52,6 +52,7 @@ from reflex.components.component import (
     evaluate_style_namespaces,
     evaluate_style_namespaces,
 )
 )
 from reflex.components.core.banner import connection_pulser, connection_toaster
 from reflex.components.core.banner import connection_pulser, connection_toaster
+from reflex.components.core.breakpoints import set_breakpoints
 from reflex.components.core.client_side_routing import (
 from reflex.components.core.client_side_routing import (
     Default404Page,
     Default404Page,
     wait_for_client_redirect,
     wait_for_client_redirect,
@@ -245,6 +246,9 @@ class App(LifespanMixin, Base):
                 "rx.BaseState cannot be subclassed multiple times. use rx.State instead"
                 "rx.BaseState cannot be subclassed multiple times. use rx.State instead"
             )
             )
 
 
+        if "breakpoints" in self.style:
+            set_breakpoints(self.style.pop("breakpoints"))
+
         # Add middleware.
         # Add middleware.
         self.middleware.append(HydrateMiddleware())
         self.middleware.append(HydrateMiddleware())
 
 

+ 1 - 0
reflex/components/core/__init__.py

@@ -32,6 +32,7 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
         "match",
         "match",
         "Match",
         "Match",
     ],
     ],
+    "breakpoints": ["breakpoints", "set_breakpoints"],
     "responsive": [
     "responsive": [
         "desktop_only",
         "desktop_only",
         "mobile_and_tablet",
         "mobile_and_tablet",

+ 2 - 0
reflex/components/core/__init__.pyi

@@ -27,6 +27,8 @@ from .html import html as html
 from .html import Html as Html
 from .html import Html as Html
 from .match import match as match
 from .match import match as match
 from .match import Match as Match
 from .match import Match as Match
+from .breakpoints import breakpoints as breakpoints
+from .breakpoints import set_breakpoints as set_breakpoints
 from .responsive import desktop_only as desktop_only
 from .responsive import desktop_only as desktop_only
 from .responsive import mobile_and_tablet as mobile_and_tablet
 from .responsive import mobile_and_tablet as mobile_and_tablet
 from .responsive import mobile_only as mobile_only
 from .responsive import mobile_only as mobile_only

+ 66 - 0
reflex/components/core/breakpoints.py

@@ -0,0 +1,66 @@
+"""Breakpoints utility."""
+
+from typing import Optional, Tuple
+
+breakpoints_values = ["30em", "48em", "62em", "80em", "96em"]
+
+
+def set_breakpoints(values: Tuple[str, str, str, str, str]):
+    """Overwrite default breakpoint values.
+
+    Args:
+        values: CSS values in order defining the breakpoints of responsive layouts
+    """
+    breakpoints_values.clear()
+    breakpoints_values.extend(values)
+
+
+class Breakpoints(dict):
+    """A responsive styling helper."""
+
+    @classmethod
+    def create(
+        cls,
+        custom: Optional[dict] = None,
+        initial=None,
+        xs=None,
+        sm=None,
+        md=None,
+        lg=None,
+        xl=None,
+    ):
+        """Create a new instance of the helper. Only provide a custom component OR use named props.
+
+        Args:
+            custom: Custom mapping using CSS values or variables.
+            initial: Styling when in the inital width
+            xs: Styling when in the extra-small width
+            sm: Styling when in the small width
+            md: Styling when in the medium width
+            lg: Styling when in the large width
+            xl: Styling when in the extra-large width
+
+        Raises:
+            ValueError: If both custom and any other named parameters are provided.
+
+        Returns:
+            The responsive mapping.
+        """
+        thresholds = [initial, xs, sm, md, lg, xl]
+
+        if custom is not None:
+            if any((threshold is not None for threshold in thresholds)):
+                raise ValueError("Named props cannot be used with custom thresholds")
+
+            return Breakpoints(custom)
+        else:
+            return Breakpoints(
+                {
+                    k: v
+                    for k, v in zip(["0px", *breakpoints_values], thresholds)
+                    if v is not None
+                }
+            )
+
+
+breakpoints = Breakpoints.create

+ 24 - 13
reflex/style.py

@@ -5,6 +5,7 @@ from __future__ import annotations
 from typing import Any, Literal, Tuple, Type
 from typing import Any, Literal, Tuple, Type
 
 
 from reflex import constants
 from reflex import constants
+from reflex.components.core.breakpoints import Breakpoints, breakpoints_values
 from reflex.event import EventChain
 from reflex.event import EventChain
 from reflex.utils import format
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
@@ -86,8 +87,6 @@ toggle_color_mode = _color_mode_var(
     _var_type=EventChain,
     _var_type=EventChain,
 )
 )
 
 
-breakpoints = ["0", "30em", "48em", "62em", "80em", "96em"]
-
 STYLE_PROP_SHORTHAND_MAPPING = {
 STYLE_PROP_SHORTHAND_MAPPING = {
     "paddingX": ("paddingInlineStart", "paddingInlineEnd"),
     "paddingX": ("paddingInlineStart", "paddingInlineEnd"),
     "paddingY": ("paddingTop", "paddingBottom"),
     "paddingY": ("paddingTop", "paddingBottom"),
@@ -100,16 +99,16 @@ STYLE_PROP_SHORTHAND_MAPPING = {
 }
 }
 
 
 
 
-def media_query(breakpoint_index: int):
+def media_query(breakpoint_expr: str):
     """Create a media query selector.
     """Create a media query selector.
 
 
     Args:
     Args:
-        breakpoint_index: The index of the breakpoint to use.
+        breakpoint_expr: The CSS expression representing the breakpoint.
 
 
     Returns:
     Returns:
         The media query selector used as a key in emotion css dict.
         The media query selector used as a key in emotion css dict.
     """
     """
-    return f"@media screen and (min-width: {breakpoints[breakpoint_index]})"
+    return f"@media screen and (min-width: {breakpoint_expr})"
 
 
 
 
 def convert_item(style_item: str | Var) -> tuple[str, VarData | None]:
 def convert_item(style_item: str | Var) -> tuple[str, VarData | None]:
@@ -189,6 +188,10 @@ def convert(style_dict):
             update_out_dict(return_val, keys)
             update_out_dict(return_val, keys)
         # Combine all the collected VarData instances.
         # Combine all the collected VarData instances.
         var_data = VarData.merge(var_data, new_var_data)
         var_data = VarData.merge(var_data, new_var_data)
+
+    if isinstance(style_dict, Breakpoints):
+        out = Breakpoints(out)
+
     return out, var_data
     return out, var_data
 
 
 
 
@@ -295,14 +298,22 @@ def format_as_emotion(style_dict: dict[str, Any]) -> Style | None:
 
 
     for orig_key, value in style_dict.items():
     for orig_key, value in style_dict.items():
         key = _format_emotion_style_pseudo_selector(orig_key)
         key = _format_emotion_style_pseudo_selector(orig_key)
-        if isinstance(value, list):
-            # Apply media queries from responsive value list.
-            mbps = {
-                media_query(bp): (
-                    bp_value if isinstance(bp_value, dict) else {key: bp_value}
-                )
-                for bp, bp_value in enumerate(value)
-            }
+        if isinstance(value, (Breakpoints, list)):
+            if isinstance(value, Breakpoints):
+                mbps = {
+                    media_query(bp): (
+                        bp_value if isinstance(bp_value, dict) else {key: bp_value}
+                    )
+                    for bp, bp_value in value.items()
+                }
+            else:
+                # Apply media queries from responsive value list.
+                mbps = {
+                    media_query([0, *breakpoints_values][bp]): (
+                        bp_value if isinstance(bp_value, dict) else {key: bp_value}
+                    )
+                    for bp, bp_value in enumerate(value)
+                }
             if key.startswith("&:"):
             if key.startswith("&:"):
                 emotion_style[key] = mbps
                 emotion_style[key] = mbps
             else:
             else: