Kaynağa Gözat

Merge remote-tracking branch 'origin/main' into reflex-0.4.0

Masen Furer 1 yıl önce
ebeveyn
işleme
d28109f4c4

+ 4 - 12
reflex/components/component.py

@@ -199,7 +199,6 @@ class Component(BaseComponent, ABC):
 
         Raises:
             TypeError: If an invalid prop is passed.
-            ValueError: If a prop value is invalid.
         """
         # Set the id and children initially.
         children = kwargs.get("children", [])
@@ -249,17 +248,10 @@ class Component(BaseComponent, ABC):
                         raise TypeError
 
                     expected_type = fields[key].outer_type_.__args__[0]
-
-                    if (
-                        types.is_literal(expected_type)
-                        and value not in expected_type.__args__
-                    ):
-                        allowed_values = expected_type.__args__
-                        if value not in allowed_values:
-                            raise ValueError(
-                                f"prop value for {key} of the `{type(self).__name__}` component should be one of the following: {','.join(allowed_values)}. Got '{value}' instead"
-                            )
-
+                    # validate literal fields.
+                    types.validate_literal(
+                        key, value, expected_type, type(self).__name__
+                    )
                     # Get the passed type and the var type.
                     passed_type = kwargs[key]._var_type
                     expected_type = (

+ 3 - 2
reflex/components/core/banner.py

@@ -34,7 +34,8 @@ class WebsocketTargetURL(Bare):
 
     def _get_imports(self) -> imports.ImportDict:
         return {
-            "/utils/state.js": [imports.ImportVar(tag="getEventURL")],
+            "/utils/state.js": [imports.ImportVar(tag="getBackendURL")],
+            "/env.json": [imports.ImportVar(tag="env", is_default=True)],
         }
 
     @classmethod
@@ -44,7 +45,7 @@ class WebsocketTargetURL(Bare):
         Returns:
             The websocket target URL component.
         """
-        return super().create(contents="{getEventURL().href}")
+        return super().create(contents="{getBackendURL(env.EVENT).href}")
 
 
 def default_connection_error() -> list[str | Var | Component]:

+ 5 - 6
reflex/components/core/colors.py

@@ -1,13 +1,12 @@
 """The colors used in Reflex are a wrapper around https://www.radix-ui.com/colors."""
 
 from reflex.constants.colors import Color, ColorType, ShadeType
+from reflex.utils.types import validate_parameter_literals
+from reflex.vars import Var
 
 
-def color(
-    color: ColorType,
-    shade: ShadeType = 7,
-    alpha: bool = False,
-) -> Color:
+@validate_parameter_literals
+def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Var:
     """Create a color object.
 
     Args:
@@ -18,4 +17,4 @@ def color(
     Returns:
         The color object.
     """
-    return Color(color, shade, alpha)
+    return Var.create(Color(color, shade, alpha))._replace(_var_is_string=True)  # type: ignore

+ 16 - 0
reflex/reflex.py

@@ -100,6 +100,9 @@ def _init(
     # Migrate Pynecone projects to Reflex.
     prerequisites.migrate_to_reflex()
 
+    if prerequisites.should_show_rx_chakra_migration_instructions():
+        prerequisites.show_rx_chakra_migration_instructions()
+
     # Initialize the .gitignore.
     prerequisites.initialize_gitignore()
 
@@ -336,6 +339,7 @@ def logout(
 
 
 db_cli = typer.Typer()
+script_cli = typer.Typer()
 
 
 def _skip_compile():
@@ -414,6 +418,17 @@ def makemigrations(
             )
 
 
+@script_cli.command(
+    name="keep-chakra",
+    help="Change all rx.<component> references to rx.chakra.<component>, to preserve Chakra UI usage.",
+)
+def keep_chakra():
+    """Change all rx.<component> references to rx.chakra.<component>, to preserve Chakra UI usage."""
+    from reflex.utils import prerequisites
+
+    prerequisites.migrate_to_rx_chakra()
+
+
 @cli.command()
 def deploy(
     key: Optional[str] = typer.Option(
@@ -555,6 +570,7 @@ def demo(
 
 
 cli.add_typer(db_cli, name="db", help="Subcommands for managing the database schema.")
+cli.add_typer(script_cli, name="script", help="Subcommands running helper scripts.")
 cli.add_typer(
     deployments_cli,
     name="deployments",

+ 8 - 12
reflex/utils/format.py

@@ -253,24 +253,20 @@ def format_cond(
     # Use Python truthiness.
     cond = f"isTrue({cond})"
 
+    def create_var(cond_part):
+        return Var.create_safe(cond_part, _var_is_string=type(cond_part) is str)
+
     # Format prop conds.
     if is_prop:
-        if not isinstance(true_value, Var):
-            true_value = Var.create_safe(
-                true_value,
-                _var_is_string=type(true_value) is str,
-            )
+        true_value = create_var(true_value)
         prop1 = true_value._replace(
             _var_is_local=True,
         )
-        if not isinstance(false_value, Var):
-            false_value = Var.create_safe(
-                false_value,
-                _var_is_string=type(false_value) is str,
-            )
+
+        false_value = create_var(false_value)
         prop2 = false_value._replace(_var_is_local=True)
-        prop1, prop2 = str(prop1), str(prop2)  # avoid f-string semantics for Var
-        return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "")
+        # unwrap '{}' to avoid f-string semantics for Var
+        return f"{cond} ? {prop1._var_name_unwrapped} : {prop2._var_name_unwrapped}"
 
     # Format component conds.
     return wrap(f"{cond} ? {true_value} : {false_value}", "{")

+ 106 - 0
reflex/utils/prerequisites.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import glob
 import importlib
+import inspect
 import json
 import os
 import platform
@@ -25,6 +26,7 @@ from alembic.util.exc import CommandError
 from packaging import version
 from redis.asyncio import Redis
 
+import reflex
 from reflex import constants, model
 from reflex.compiler import templates
 from reflex.config import Config, get_config
@@ -938,6 +940,110 @@ def prompt_for_template() -> constants.Templates.Kind:
     return constants.Templates.Kind(template)
 
 
+def should_show_rx_chakra_migration_instructions() -> bool:
+    """Should we show the migration instructions for rx.chakra.* => rx.*?.
+
+    Returns:
+        bool: True if we should show the migration instructions.
+    """
+    if os.getenv("REFLEX_PROMPT_MIGRATE_TO_RX_CHAKRA") == "yes":
+        return True
+
+    with open(constants.Dirs.REFLEX_JSON, "r") as f:
+        data = json.load(f)
+        existing_init_reflex_version = data.get("version", None)
+
+    if existing_init_reflex_version is None:
+        # They clone a reflex app from git for the first time.
+        # That app may or may not be 0.4 compatible.
+        # So let's just show these instructions THIS TIME.
+        return True
+
+    if constants.Reflex.VERSION < "0.4":
+        return False
+    else:
+        return existing_init_reflex_version < "0.4"
+
+
+def show_rx_chakra_migration_instructions():
+    """Show the migration instructions for rx.chakra.* => rx.*."""
+    console.log(
+        "Prior to reflex 0.4.0, rx.* components are based on Chakra UI. They are now based on Radix UI. To stick to Chakra UI, use rx.chakra.*."
+    )
+    console.log("")
+    console.log(
+        "[bold]Run `reflex script keep-chakra` to automatically update your app."
+    )
+    console.log("")
+    console.log("For more details, please see https://TODO")  # TODO add link to docs
+
+
+def migrate_to_rx_chakra():
+    """Migrate rx.button => r.chakra.button, etc."""
+    file_pattern = os.path.join(get_config().app_name, "**/*.py")
+    file_list = glob.glob(file_pattern, recursive=True)
+
+    # Populate with all rx.<x> components that have been moved to rx.chakra.<x>
+    patterns = {
+        rf"\brx\.{name}\b": f"rx.chakra.{name}"
+        for name in _get_rx_chakra_component_to_migrate()
+    }
+
+    for file_path in file_list:
+        with FileInput(file_path, inplace=True) as file:
+            for _line_num, line in enumerate(file):
+                for old, new in patterns.items():
+                    line = re.sub(old, new, line)
+                print(line, end="")
+
+
+def _get_rx_chakra_component_to_migrate() -> set[str]:
+    from reflex.components import ChakraComponent
+
+    rx_chakra_names = set(dir(reflex.chakra))
+
+    names_to_migrate = set()
+    whitelist = {
+        "CodeBlock",
+        "ColorModeIcon",
+        "MultiSelect",
+        "MultiSelectOption",
+        "base",
+        "code_block",
+        "color_mode_cond",
+        "color_mode_icon",
+        "multi_select",
+        "multi_select_option",
+    }
+    for rx_chakra_name in sorted(rx_chakra_names):
+        if rx_chakra_name.startswith("_"):
+            continue
+
+        rx_chakra_object = getattr(reflex.chakra, rx_chakra_name)
+        try:
+            if (
+                inspect.ismethod(rx_chakra_object)
+                and inspect.isclass(rx_chakra_object.__self__)
+                and issubclass(rx_chakra_object.__self__, ChakraComponent)
+            ):
+                names_to_migrate.add(rx_chakra_name)
+
+            elif inspect.isclass(rx_chakra_object) and issubclass(
+                rx_chakra_object, ChakraComponent
+            ):
+                names_to_migrate.add(rx_chakra_name)
+                pass
+            else:
+                # For the given rx.chakra.<x>, does rx.<x> exist?
+                # And of these, should we include in migration?
+                if hasattr(reflex, rx_chakra_name) and rx_chakra_name in whitelist:
+                    names_to_migrate.add(rx_chakra_name)
+
+        except Exception:
+            raise
+    return names_to_migrate
+
+
 def migrate_to_reflex():
     """Migration from Pynecone to Reflex."""
     # Check if the old config file exists.

+ 65 - 0
reflex/utils/types.py

@@ -3,7 +3,9 @@
 from __future__ import annotations
 
 import contextlib
+import inspect
 import types
+from functools import wraps
 from typing import (
     Any,
     Callable,
@@ -330,6 +332,69 @@ def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
     return type_ in allowed_types
 
 
+def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str):
+    """Check that a value is a valid literal.
+
+    Args:
+        key: The prop name.
+        value: The prop value to validate.
+        expected_type: The expected type(literal type).
+        comp_name: Name of the component.
+
+    Raises:
+        ValueError: When the value is not a valid literal.
+    """
+    from reflex.vars import Var
+
+    if (
+        is_literal(expected_type)
+        and not isinstance(value, Var)  # validating vars is not supported yet.
+        and value not in expected_type.__args__
+    ):
+        allowed_values = expected_type.__args__
+        if value not in allowed_values:
+            value_str = ",".join(
+                [str(v) if not isinstance(v, str) else f"'{v}'" for v in allowed_values]
+            )
+            raise ValueError(
+                f"prop value for {str(key)} of the `{comp_name}` component should be one of the following: {value_str}. Got '{value}' instead"
+            )
+
+
+def validate_parameter_literals(func):
+    """Decorator to check that the arguments passed to a function
+    correspond to the correct function parameter if it (the parameter)
+    is a literal type.
+
+    Args:
+        func: The function to validate.
+
+    Returns:
+        The wrapper function.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        func_params = list(inspect.signature(func).parameters.items())
+        annotations = {param[0]: param[1].annotation for param in func_params}
+
+        # validate args
+        for param, arg in zip(annotations.keys(), args):
+            if annotations[param] is inspect.Parameter.empty:
+                continue
+            validate_literal(param, arg, annotations[param], func.__name__)
+
+        # validate kwargs.
+        for key, value in kwargs.items():
+            annotation = annotations.get(key)
+            if not annotation or annotation is inspect.Parameter.empty:
+                continue
+            validate_literal(key, value, annotation, func.__name__)
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 # Store this here for performance.
 StateBases = get_base_class(StateVar)
 StateIterBases = get_base_class(StateIterVar)

+ 3 - 3
reflex/vars.py

@@ -1612,13 +1612,13 @@ class Var:
             if types.is_generic_alias(self._var_type)
             else self._var_type
         )
-
         wrapped_var = str(self)
+
         return (
             wrapped_var
             if not self._var_state
-            and issubclass(type_, dict)
-            or issubclass(type_, Style)
+            and types._issubclass(type_, dict)
+            or types._issubclass(type_, Style)
             else wrapped_var.strip("{}")
         )
 

+ 13 - 0
scripts/migrate_project_to_rx_chakra.py

@@ -0,0 +1,13 @@
+"""Migrate project to rx.chakra. I.e. switch usage of rx.<component> to rx.chakra.<component>."""
+
+import argparse
+
+if __name__ == "__main__":
+    # parse args just for the help message (-h, etc)
+    parser = argparse.ArgumentParser(
+        description="Migrate project to rx.chakra. I.e. switch usage of rx.<component> to rx.chakra.<component>."
+    )
+    args = parser.parse_args()
+    from reflex.utils.prerequisites import migrate_to_rx_chakra
+
+    migrate_to_rx_chakra()

+ 0 - 0
tests/components/core/__init__.py


+ 66 - 0
tests/components/core/test_colors.py

@@ -0,0 +1,66 @@
+import pytest
+
+import reflex as rx
+
+
+class ColorState(rx.State):
+    """Test color state."""
+
+    color: str = "mint"
+    shade: int = 4
+
+
+@pytest.mark.parametrize(
+    "color, expected",
+    [
+        (rx.color("mint"), "{`var(--mint-7)`}"),
+        (rx.color("mint", 3), "{`var(--mint-3)`}"),
+        (rx.color("mint", 3, True), "{`var(--mint-a3)`}"),
+        (
+            rx.color(ColorState.color, ColorState.shade),  # type: ignore
+            "{`var(--${state__color_state.color}-${state__color_state.shade})`}",
+        ),
+    ],
+)
+def test_color(color, expected):
+    assert str(color) == expected
+
+
+@pytest.mark.parametrize(
+    "cond_var, expected",
+    [
+        (
+            rx.cond(True, rx.color("mint"), rx.color("tomato", 5)),
+            "{isTrue(true) ? `var(--mint-7)` : `var(--tomato-5)`}",
+        ),
+        (
+            rx.cond(True, rx.color(ColorState.color), rx.color(ColorState.color, 5)),  # type: ignore
+            "{isTrue(true) ? `var(--${state__color_state.color}-7)` : `var(--${state__color_state.color}-5)`}",
+        ),
+        (
+            rx.match(
+                "condition",
+                ("first", rx.color("mint")),
+                ("second", rx.color("tomato", 5)),
+                rx.color(ColorState.color, 2),  # type: ignore
+            ),
+            "{(() => { switch (JSON.stringify(`condition`)) {case JSON.stringify(`first`):  return (`var(--mint-7)`);"
+            "  break;case JSON.stringify(`second`):  return (`var(--tomato-5)`);  break;default:  "
+            "return (`var(--${state__color_state.color}-2)`);  break;};})()}",
+        ),
+        (
+            rx.match(
+                "condition",
+                ("first", rx.color(ColorState.color)),  # type: ignore
+                ("second", rx.color(ColorState.color, 5)),  # type: ignore
+                rx.color(ColorState.color, 2),  # type: ignore
+            ),
+            "{(() => { switch (JSON.stringify(`condition`)) {case JSON.stringify(`first`):  "
+            "return (`var(--${state__color_state.color}-7)`);  break;case JSON.stringify(`second`):  "
+            "return (`var(--${state__color_state.color}-5)`);  break;default:  "
+            "return (`var(--${state__color_state.color}-2)`);  break;};})()}",
+        ),
+    ],
+)
+def test_color_with_conditionals(cond_var, expected):
+    assert str(cond_var) == expected

+ 3 - 3
tests/utils/test_serializers.py

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Type
 import pytest
 
 from reflex.base import Base
-from reflex.components.core.colors import color
+from reflex.components.core.colors import Color
 from reflex.utils import serializers
 from reflex.vars import Var
 
@@ -170,8 +170,8 @@ class BaseSubclass(Base):
             [datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)],
             '["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]',
         ),
-        (color("slate", shade=1), "var(--slate-1)"),
-        (color("orange", shade=1, alpha=True), "var(--orange-a1)"),
+        (Color(color="slate", shade=1), "var(--slate-1)"),
+        (Color(color="orange", shade=1, alpha=True), "var(--orange-a1)"),
     ],
 )
 def test_serialize(value: Any, expected: str):