Kaynağa Gözat

rx.match component (#2318)

* initial commit

* add more tests

* refactor match jinja template

* add docstrings

* cleanup

* possible fix for pyright

* fix conflicts

* fix conflicts again

* comments

* fixed bug from review

* fix tests

* address PR comment

* fix tests

* type error improvement

* formatting

* darglint fix

* more tests

* stringify switch condition and cases as js doesnt support complex types(lists and dicts) in switch cases.

* Update reflex/vars.py

Co-authored-by: Masen Furer <m_github@0x26.net>

* change usages

* Precommit fix

---------

Co-authored-by: Alek Petuskey <alek@pynecone.io>
Co-authored-by: Masen Furer <m_github@0x26.net>
Co-authored-by: Alek Petuskey <alekpetuskey@aleks-mbp.lan>
Elijah Ahianyo 1 yıl önce
ebeveyn
işleme
abfc099779

+ 24 - 0
reflex/.templates/jinja/web/pages/utils.js.jinja2

@@ -8,6 +8,8 @@
     {{- component }}
   {%- elif "iterable" in component %}
     {{- render_iterable_tag(component) }}
+  {%- elif component.name == "match"%}
+    {{- render_match_tag(component) }}
   {%- elif "cond" in component %}
     {{- render_condition_tag(component) }}
   {%- elif component.children|length %}
@@ -77,6 +79,28 @@
 {% if props|length %} {{ props|join(" ") }}{% endif %}
 {% endmacro %}
 
+{# Rendering Match component. #}
+{# Args: #}
+{#     component: component dictionary #}
+{% macro render_match_tag(component) %}
+{
+    (() => {
+        switch (JSON.stringify({{ component.cond._var_full_name }})) {
+        {% for case in component.match_cases %}
+            {% for condition in case[:-1] %}
+                case JSON.stringify({{ condition._var_name_unwrapped }}):
+            {% endfor %}
+                return {{ case[-1] }};
+                break;
+        {% endfor %}
+            default:
+                return {{ component.default }};
+                break;
+        }
+    })()
+  }
+{%- endmacro %}
+
 
 {# Rendering content with args. #}
 {# Args: #}

+ 1 - 0
reflex/__init__.py

@@ -114,6 +114,7 @@ _ALL_COMPONENTS = [
     "List",
     "ListItem",
     "Markdown",
+    "Match",
     "Menu",
     "MenuButton",
     "MenuDivider",

+ 2 - 0
reflex/__init__.pyi

@@ -107,6 +107,7 @@ from reflex.components import LinkOverlay as LinkOverlay
 from reflex.components import List as List
 from reflex.components import ListItem as ListItem
 from reflex.components import Markdown as Markdown
+from reflex.components import Match as Match
 from reflex.components import Menu as Menu
 from reflex.components import MenuButton as MenuButton
 from reflex.components import MenuDivider as MenuDivider
@@ -317,6 +318,7 @@ from reflex.components import link_overlay as link_overlay
 from reflex.components import list as list
 from reflex.components import list_item as list_item
 from reflex.components import markdown as markdown
+from reflex.components import match as match
 from reflex.components import menu as menu
 from reflex.components import menu_button as menu_button
 from reflex.components import menu_divider as menu_divider

+ 4 - 1
reflex/app.py

@@ -64,7 +64,7 @@ from reflex.state import (
     StateManager,
     StateUpdate,
 )
-from reflex.utils import console, format, prerequisites, types
+from reflex.utils import console, exceptions, format, prerequisites, types
 from reflex.utils.imports import ImportVar
 
 # Define custom types.
@@ -344,9 +344,12 @@ class App(Base):
 
         Raises:
             TypeError: When an invalid component function is passed.
+            exceptions.MatchTypeError: If the return types of match cases in rx.match are different.
         """
         try:
             return component if isinstance(component, Component) else component()
+        except exceptions.MatchTypeError:
+            raise
         except TypeError as e:
             message = str(e)
             if "BaseVar" in message or "ComputedVar" in message:

+ 2 - 0
reflex/components/core/__init__.py

@@ -4,6 +4,7 @@ from .banner import ConnectionBanner, ConnectionModal
 from .cond import Cond, cond
 from .debounce import DebounceInput
 from .foreach import Foreach
+from .match import Match
 from .responsive import (
     desktop_only,
     mobile_and_tablet,
@@ -17,4 +18,5 @@ connection_banner = ConnectionBanner.create
 connection_modal = ConnectionModal.create
 debounce_input = DebounceInput.create
 foreach = Foreach.create
+match = Match.create
 upload = Upload.create

+ 257 - 0
reflex/components/core/match.py

@@ -0,0 +1,257 @@
+"""rx.match."""
+import textwrap
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from reflex.components.base import Fragment
+from reflex.components.component import BaseComponent, Component, MemoizationLeaf
+from reflex.components.tags import MatchTag, Tag
+from reflex.utils import format, imports, types
+from reflex.utils.exceptions import MatchTypeError
+from reflex.vars import BaseVar, Var, VarData
+
+
+class Match(MemoizationLeaf):
+    """Match cases based on a condition."""
+
+    # The condition to determine which case to match.
+    cond: Var[Any]
+
+    # The list of match cases to be matched.
+    match_cases: List[Any] = []
+
+    # The catchall case to match.
+    default: Any
+
+    @classmethod
+    def create(cls, cond: Any, *cases) -> Union[Component, BaseVar]:
+        """Create a Match Component.
+
+        Args:
+            cond: The condition to determine which case to match.
+            cases: This list of cases to match.
+
+        Returns:
+            The match component.
+
+        Raises:
+            ValueError: When a default case is not provided for cases with Var return types.
+        """
+        match_cond_var = cls._create_condition_var(cond)
+        cases, default = cls._process_cases(list(cases))
+        match_cases = cls._process_match_cases(cases)
+
+        cls._validate_return_types(match_cases)
+
+        if default is None and types._issubclass(type(match_cases[0][-1]), BaseVar):
+            raise ValueError(
+                "For cases with return types as Vars, a default case must be provided"
+            )
+
+        return cls._create_match_cond_var_or_component(
+            match_cond_var, match_cases, default
+        )
+
+    @classmethod
+    def _create_condition_var(cls, cond: Any) -> BaseVar:
+        """Convert the condition to a Var.
+
+        Args:
+            cond: The condition.
+
+        Returns:
+            The condition as a base var
+
+        Raises:
+            ValueError: If the condition is not provided.
+        """
+        match_cond_var = Var.create(cond)
+        if match_cond_var is None:
+            raise ValueError("The condition must be set")
+        return match_cond_var  # type: ignore
+
+    @classmethod
+    def _process_cases(
+        cls, cases: List
+    ) -> Tuple[List, Optional[Union[BaseVar, BaseComponent]]]:
+        """Process the list of match cases and the catchall default case.
+
+        Args:
+            cases: The list of match cases.
+
+        Returns:
+            The default case and the list of match case tuples.
+
+        Raises:
+            ValueError: If there are multiple default cases.
+        """
+        default = None
+
+        if len([case for case in cases if not isinstance(case, tuple)]) > 1:
+            raise ValueError("rx.match can only have one default case.")
+
+        # Get the default case which should be the last non-tuple arg
+        if not isinstance(cases[-1], tuple):
+            default = cases.pop()
+            default = (
+                Var.create(default, _var_is_string=type(default) is str)
+                if not isinstance(default, BaseComponent)
+                else default
+            )
+
+        return cases, default  # type: ignore
+
+    @classmethod
+    def _process_match_cases(cls, cases: List) -> List[List[BaseVar]]:
+        """Process the individual match cases.
+
+        Args:
+            cases: The match cases.
+
+        Returns:
+            The processed match cases.
+
+        Raises:
+            ValueError: If the default case is not the last case or the tuple elements are less than 2.
+        """
+        match_cases = []
+        for case in cases:
+            if not isinstance(case, tuple):
+                raise ValueError(
+                    "rx.match should have tuples of cases and a default case as the last argument."
+                )
+            # There should be at least two elements in a case tuple(a condition and return value)
+            if len(case) < 2:
+                raise ValueError(
+                    "A case tuple should have at least a match case element and a return value."
+                )
+
+            case_list = []
+            for element in case:
+                # convert all non component element to vars.
+                el = (
+                    Var.create(element, _var_is_string=type(element) is str)
+                    if not isinstance(element, BaseComponent)
+                    else element
+                )
+                if not isinstance(el, (BaseVar, BaseComponent)):
+                    raise ValueError("Case element must be a var or component")
+                case_list.append(el)
+
+            match_cases.append(case_list)
+
+        return match_cases
+
+    @classmethod
+    def _validate_return_types(cls, match_cases: List[List[BaseVar]]) -> None:
+        """Validate that match cases have the same return types.
+
+        Args:
+            match_cases: The match cases.
+
+        Raises:
+            MatchTypeError: If the return types of cases are different.
+        """
+        first_case_return = match_cases[0][-1]
+        return_type = type(first_case_return)
+
+        if types._isinstance(first_case_return, BaseComponent):
+            return_type = BaseComponent
+        elif types._isinstance(first_case_return, BaseVar):
+            return_type = BaseVar
+
+        for index, case in enumerate(match_cases):
+            if not types._issubclass(type(case[-1]), return_type):
+                raise MatchTypeError(
+                    f"Match cases should have the same return types. Case {index} with return "
+                    f"value `{case[-1]._var_name if isinstance(case[-1], BaseVar) else textwrap.shorten(str(case[-1]), width=250)}`"
+                    f" of type {type(case[-1])!r} is not {return_type}"
+                )
+
+    @classmethod
+    def _create_match_cond_var_or_component(
+        cls,
+        match_cond_var: Var,
+        match_cases: List[List[BaseVar]],
+        default: Optional[Union[BaseVar, BaseComponent]],
+    ) -> Union[Component, BaseVar]:
+        """Create and return the match condition var or component.
+
+        Args:
+            match_cond_var: The match condition.
+            match_cases: The list of match cases.
+            default: The default case.
+
+        Returns:
+            The match component wrapped in a fragment or the match var.
+
+        Raises:
+            ValueError: If the return types are not vars when creating a match var for Var types.
+        """
+        if default is None and types._issubclass(
+            type(match_cases[0][-1]), BaseComponent
+        ):
+            default = Fragment.create()
+
+        if types._issubclass(type(match_cases[0][-1]), BaseComponent):
+            return Fragment.create(
+                cls(
+                    cond=match_cond_var,
+                    match_cases=match_cases,
+                    default=default,
+                )
+            )
+
+        # Validate the match cases (as well as the default case) to have Var return types.
+        if any(
+            case for case in match_cases if not types._isinstance(case[-1], BaseVar)
+        ) or not types._isinstance(default, BaseVar):
+            raise ValueError("Return types of match cases should be Vars.")
+
+        # match cases and default should all be Vars at this point.
+        # Retrieve var data of every var in the match cases and default.
+        var_data = [
+            *[el._var_data for case in match_cases for el in case],
+            default._var_data,  # type: ignore
+        ]
+
+        return match_cond_var._replace(
+            _var_name=format.format_match(
+                cond=match_cond_var._var_full_name,
+                match_cases=match_cases,  # type: ignore
+                default=default,  # type: ignore
+            ),
+            _var_type=default._var_type,  # type: ignore
+            _var_is_local=False,
+            _var_full_name_needs_state_prefix=False,
+            merge_var_data=VarData.merge(*var_data),
+        )
+
+    def _render(self) -> Tag:
+        return MatchTag(
+            cond=self.cond, match_cases=self.match_cases, default=self.default
+        )
+
+    def render(self) -> Dict:
+        """Render the component.
+
+        Returns:
+            The dictionary for template of component.
+        """
+        tag = self._render()
+        tag.name = "match"
+        return dict(tag)
+
+    def _get_imports(self):
+        merged_imports = super()._get_imports()
+        # Obtain the imports of all components the in match case.
+        for case in self.match_cases:
+            if isinstance(case[-1], BaseComponent):
+                merged_imports = imports.merge_imports(
+                    merged_imports, case[-1].get_imports()
+                )
+        # Get the import of the default case component.
+        if isinstance(self.default, BaseComponent):
+            merged_imports = imports.merge_imports(
+                merged_imports, self.default.get_imports()
+            )
+        return merged_imports

+ 1 - 0
reflex/components/tags/__init__.py

@@ -2,4 +2,5 @@
 
 from .cond_tag import CondTag
 from .iter_tag import IterTag
+from .match_tag import MatchTag
 from .tag import Tag

+ 19 - 0
reflex/components/tags/match_tag.py

@@ -0,0 +1,19 @@
+"""Tag to conditionally match cases."""
+
+from typing import Any, List
+
+from reflex.components.tags.tag import Tag
+from reflex.vars import Var
+
+
+class MatchTag(Tag):
+    """A match tag."""
+
+    # The condition to determine which case to match.
+    cond: Var[Any]
+
+    # The list of match cases to be matched.
+    match_cases: List[Any]
+
+    # The catchall case to match.
+    default: Any

+ 6 - 0
reflex/utils/exceptions.py

@@ -13,3 +13,9 @@ class ImmutableStateError(AttributeError):
 
 class LockExpiredError(Exception):
     """Raised when the state lock expires while an event is being processed."""
+
+
+class MatchTypeError(TypeError):
+    """Raised when the return types of match cases are different."""
+
+    pass

+ 36 - 1
reflex/utils/format.py

@@ -7,7 +7,7 @@ import json
 import os
 import re
 import sys
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, List, Union
 
 from reflex import constants
 from reflex.utils import exceptions, serializers, types
@@ -272,6 +272,41 @@ def format_cond(
     return wrap(f"{cond} ? {true_value} : {false_value}", "{")
 
 
+def format_match(cond: str | Var, match_cases: List[BaseVar], default: Var) -> str:
+    """Format a match expression whose return type is a Var.
+
+    Args:
+        cond: The condition.
+        match_cases: The list of cases to match.
+        default: The default case.
+
+    Returns:
+        The formatted match expression
+
+    """
+    switch_code = f"(() => {{ switch (JSON.stringify({cond})) {{"
+
+    for case in match_cases:
+        conditions = case[:-1]
+        return_value = case[-1]
+
+        case_conditions = " ".join(
+            [
+                f"case JSON.stringify({condition._var_name_unwrapped}):"
+                for condition in conditions
+            ]
+        )
+        case_code = (
+            f"{case_conditions}  return ({return_value._var_name_unwrapped});  break;"
+        )
+        switch_code += case_code
+
+    switch_code += f"default:  return ({default._var_name_unwrapped});  break;"
+    switch_code += "};})()"
+
+    return switch_code
+
+
 def format_prop(
     prop: Union[Var, EventChain, ComponentStyle, str],
 ) -> Union[int, float, str]:

+ 20 - 0
reflex/vars.py

@@ -1534,6 +1534,26 @@ class Var:
         """
         return self._var_data.state if self._var_data else ""
 
+    @property
+    def _var_name_unwrapped(self) -> str:
+        """Get the var str without wrapping in curly braces.
+
+        Returns:
+            The str var without the wrapped curly braces
+        """
+        type_ = (
+            get_origin(self._var_type)
+            if types.is_generic_alias(self._var_type)
+            else self._var_type
+        )
+
+        wrapped_var = str(self)
+        return (
+            wrapped_var
+            if not self._var_state and issubclass(type_, dict)
+            else wrapped_var.strip("{}")
+        )
+
 
 # Allow automatic serialization of Var within JSON structures
 serializers.serializer(_encode_var)

+ 1 - 0
scripts/pyi_generator.py

@@ -30,6 +30,7 @@ EXCLUDED_FILES = [
     "bare.py",
     "foreach.py",
     "cond.py",
+    "match.py",
     "multiselect.py",
     "literals.py",
 ]

+ 306 - 0
tests/components/layout/test_match.py

@@ -0,0 +1,306 @@
+from typing import Tuple
+
+import pytest
+
+import reflex as rx
+from reflex.components.core.match import Match
+from reflex.state import BaseState
+from reflex.utils.exceptions import MatchTypeError
+from reflex.vars import BaseVar
+
+
+class MatchState(BaseState):
+    """A test state."""
+
+    value: int = 0
+    num: int = 5
+    string: str = "random string"
+
+
+def test_match_components():
+    """Test matching cases with return values as components."""
+    match_case_tuples = (
+        (1, rx.text("first value")),
+        (2, 3, rx.text("second value")),
+        ([1, 2], rx.text("third value")),
+        ("random", rx.text("fourth value")),
+        ({"foo": "bar"}, rx.text("fifth value")),
+        (MatchState.num + 1, rx.text("sixth value")),
+        rx.text("default value"),
+    )
+    match_comp = Match.create(MatchState.value, *match_case_tuples)
+    match_dict = match_comp.render()  # type: ignore
+    assert match_dict["name"] == "Fragment"
+
+    [match_child] = match_dict["children"]
+
+    assert match_child["name"] == "match"
+    assert str(match_child["cond"]) == "{match_state.value}"
+
+    match_cases = match_child["match_cases"]
+    assert len(match_cases) == 6
+
+    assert match_cases[0][0]._var_name == "1"
+    assert match_cases[0][0]._var_type == int
+    first_return_value_render = match_cases[0][1].render()
+    assert first_return_value_render["name"] == "Text"
+    assert first_return_value_render["children"][0]["contents"] == "{`first value`}"
+
+    assert match_cases[1][0]._var_name == "2"
+    assert match_cases[1][0]._var_type == int
+    assert match_cases[1][1]._var_name == "3"
+    assert match_cases[1][1]._var_type == int
+    second_return_value_render = match_cases[1][2].render()
+    assert second_return_value_render["name"] == "Text"
+    assert second_return_value_render["children"][0]["contents"] == "{`second value`}"
+
+    assert match_cases[2][0]._var_name == "[1, 2]"
+    assert match_cases[2][0]._var_type == list
+    third_return_value_render = match_cases[2][1].render()
+    assert third_return_value_render["name"] == "Text"
+    assert third_return_value_render["children"][0]["contents"] == "{`third value`}"
+
+    assert match_cases[3][0]._var_name == "random"
+    assert match_cases[3][0]._var_type == str
+    fourth_return_value_render = match_cases[3][1].render()
+    assert fourth_return_value_render["name"] == "Text"
+    assert fourth_return_value_render["children"][0]["contents"] == "{`fourth value`}"
+
+    assert match_cases[4][0]._var_name == '{"foo": "bar"}'
+    assert match_cases[4][0]._var_type == dict
+    fifth_return_value_render = match_cases[4][1].render()
+    assert fifth_return_value_render["name"] == "Text"
+    assert fifth_return_value_render["children"][0]["contents"] == "{`fifth value`}"
+
+    assert match_cases[5][0]._var_name == "(match_state.num + 1)"
+    assert match_cases[5][0]._var_type == int
+    fifth_return_value_render = match_cases[5][1].render()
+    assert fifth_return_value_render["name"] == "Text"
+    assert fifth_return_value_render["children"][0]["contents"] == "{`sixth value`}"
+
+    default = match_child["default"].render()
+
+    assert default["name"] == "Text"
+    assert default["children"][0]["contents"] == "{`default value`}"
+
+
+@pytest.mark.parametrize(
+    "cases, expected",
+    [
+        (
+            (
+                (1, "first"),
+                (2, 3, "second value"),
+                ([1, 2], "third-value"),
+                ("random", "fourth_value"),
+                ({"foo": "bar"}, "fifth value"),
+                (MatchState.num + 1, "sixth value"),
+                (f"{MatchState.value} - string", MatchState.string),
+                (MatchState.string, f"{MatchState.value} - string"),
+                "default value",
+            ),
+            "(() => { switch (JSON.stringify(match_state.value)) {case JSON.stringify(1):  return (`first`);  break;case JSON.stringify(2): case JSON.stringify(3):  return "
+            "(`second value`);  break;case JSON.stringify([1, 2]):  return (`third-value`);  break;case JSON.stringify(`random`):  "
+            'return (`fourth_value`);  break;case JSON.stringify({"foo": "bar"}):  return (`fifth value`);  '
+            "break;case JSON.stringify((match_state.num + 1)):  return (`sixth value`);  break;case JSON.stringify(`${match_state.value} - string`):  "
+            "return (match_state.string);  break;case JSON.stringify(match_state.string):  return (`${match_state.value} - string`);  break;default:  "
+            "return (`default value`);  break;};})()",
+        ),
+        (
+            (
+                (1, "first"),
+                (2, 3, "second value"),
+                ([1, 2], "third-value"),
+                ("random", "fourth_value"),
+                ({"foo": "bar"}, "fifth value"),
+                (MatchState.num + 1, "sixth value"),
+                (f"{MatchState.value} - string", MatchState.string),
+                (MatchState.string, f"{MatchState.value} - string"),
+                MatchState.string,
+            ),
+            "(() => { switch (JSON.stringify(match_state.value)) {case JSON.stringify(1):  return (`first`);  break;case JSON.stringify(2): case JSON.stringify(3):  return "
+            "(`second value`);  break;case JSON.stringify([1, 2]):  return (`third-value`);  break;case JSON.stringify(`random`):  "
+            'return (`fourth_value`);  break;case JSON.stringify({"foo": "bar"}):  return (`fifth value`);  '
+            "break;case JSON.stringify((match_state.num + 1)):  return (`sixth value`);  break;case JSON.stringify(`${match_state.value} - string`):  "
+            "return (match_state.string);  break;case JSON.stringify(match_state.string):  return (`${match_state.value} - string`);  break;default:  "
+            "return (match_state.string);  break;};})()",
+        ),
+    ],
+)
+def test_match_vars(cases, expected):
+    """Test matching cases with return values as Vars.
+
+    Args:
+        cases: The match cases.
+        expected: The expected var full name.
+    """
+    match_comp = Match.create(MatchState.value, *cases)
+    assert isinstance(match_comp, BaseVar)
+    assert match_comp._var_full_name == expected
+
+
+def test_match_on_component_without_default():
+    """Test that matching cases with return values as components returns a Fragment
+    as the default case if not provided.
+    """
+    match_case_tuples = (
+        (1, rx.text("first value")),
+        (2, 3, rx.text("second value")),
+    )
+
+    match_comp = Match.create(MatchState.value, *match_case_tuples)
+    default = match_comp.render()["children"][0]["default"]  # type: ignore
+
+    assert isinstance(default, rx.Fragment)
+
+
+def test_match_on_var_no_default():
+    """Test that an error is thrown when cases with return Values as Var do not have a default case."""
+    match_case_tuples = (
+        (1, "red"),
+        (2, 3, "blue"),
+        ([1, 2], "green"),
+    )
+
+    with pytest.raises(
+        ValueError,
+        match="For cases with return types as Vars, a default case must be provided",
+    ):
+        Match.create(MatchState.value, *match_case_tuples)
+
+
+@pytest.mark.parametrize(
+    "match_case",
+    [
+        (
+            (1, "red"),
+            (2, 3, "blue"),
+            "black",
+            ([1, 2], "green"),
+        ),
+        (
+            (1, rx.text("first value")),
+            (2, 3, rx.text("second value")),
+            ([1, 2], rx.text("third value")),
+            rx.text("default value"),
+            ("random", rx.text("fourth value")),
+            ({"foo": "bar"}, rx.text("fifth value")),
+            (MatchState.num + 1, rx.text("sixth value")),
+        ),
+    ],
+)
+def test_match_default_not_last_arg(match_case):
+    """Test that an error is thrown when the default case is not the last arg.
+
+    Args:
+        match_case: The cases to match.
+    """
+    with pytest.raises(
+        ValueError,
+        match="rx.match should have tuples of cases and a default case as the last argument.",
+    ):
+        Match.create(MatchState.value, *match_case)
+
+
+@pytest.mark.parametrize(
+    "match_case",
+    [
+        (
+            (1, "red"),
+            (2, 3, "blue"),
+            ("green",),
+            "black",
+        ),
+        (
+            (1, rx.text("first value")),
+            (2, 3, rx.text("second value")),
+            ([1, 2],),
+            rx.text("default value"),
+        ),
+    ],
+)
+def test_match_case_tuple_elements(match_case):
+    """Test that a match has at least 2 elements(a condition and a return value).
+
+    Args:
+        match_case: The cases to match.
+    """
+    with pytest.raises(
+        ValueError,
+        match="A case tuple should have at least a match case element and a return value.",
+    ):
+        Match.create(MatchState.value, *match_case)
+
+
+@pytest.mark.parametrize(
+    "cases, error_msg",
+    [
+        (
+            (
+                (1, rx.text("first value")),
+                (2, 3, rx.text("second value")),
+                ([1, 2], rx.text("third value")),
+                ("random", "red"),
+                ({"foo": "bar"}, "green"),
+                (MatchState.num + 1, "black"),
+                rx.text("default value"),
+            ),
+            "Match cases should have the same return types. Case 3 with return value `red` of type "
+            "<class 'reflex.vars.BaseVar'> is not <class 'reflex.components.component.BaseComponent'>",
+        ),
+        (
+            (
+                ("random", "red"),
+                ({"foo": "bar"}, "green"),
+                (MatchState.num + 1, "black"),
+                (1, rx.text("first value")),
+                (2, 3, rx.text("second value")),
+                ([1, 2], rx.text("third value")),
+                rx.text("default value"),
+            ),
+            "Match cases should have the same return types. Case 3 with return value `<Text> {`first value`} </Text>` "
+            "of type <class 'reflex.components.chakra.typography.text.Text'> is not <class 'reflex.vars.BaseVar'>",
+        ),
+    ],
+)
+def test_match_different_return_types(cases: Tuple, error_msg: str):
+    """Test that an error is thrown when the return values are of different types.
+
+    Args:
+        cases: The match cases.
+        error_msg: Expected error message.
+    """
+    with pytest.raises(MatchTypeError, match=error_msg):
+        Match.create(MatchState.value, *cases)
+
+
+@pytest.mark.parametrize(
+    "match_case",
+    [
+        (
+            (1, "red"),
+            (2, 3, "blue"),
+            ([1, 2], "green"),
+            "black",
+            "white",
+        ),
+        (
+            (1, rx.text("first value")),
+            (2, 3, rx.text("second value")),
+            ([1, 2], rx.text("third value")),
+            ("random", rx.text("fourth value")),
+            ({"foo": "bar"}, rx.text("fifth value")),
+            (MatchState.num + 1, rx.text("sixth value")),
+            rx.text("default value"),
+            rx.text("another default value"),
+        ),
+    ],
+)
+def test_match_multiple_default_cases(match_case):
+    """Test that there is only one default case.
+
+    Args:
+        match_case: the cases to match.
+    """
+    with pytest.raises(ValueError, match="rx.match can only have one default case."):
+        Match.create(MatchState.value, *match_case)

+ 1 - 0
tests/components/test_component.py

@@ -896,6 +896,7 @@ def test_instantiate_all_components():
         "FormControl",
         "Html",
         "Icon",
+        "Match",
         "Markdown",
         "MultiSelect",
         "Option",

+ 26 - 0
tests/test_var.py

@@ -24,6 +24,13 @@ test_vars = [
 ]
 
 
+class ATestState(BaseState):
+    """Test state."""
+
+    value: str
+    dict_val: Dict[str, List] = {}
+
+
 @pytest.fixture
 def TestObj():
     class TestObj(Base):
@@ -1137,3 +1144,22 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List
 
         with pytest.raises(TypeError):
             operand1_var.operation(op=operator, other=operand2_var, flip=True)
+
+
+@pytest.mark.parametrize(
+    "var, expected",
+    [
+        (Var.create("string_value", _var_is_string=True), "`string_value`"),
+        (Var.create(1), "1"),
+        (Var.create([1, 2, 3]), "[1, 2, 3]"),
+        (Var.create({"foo": "bar"}), '{"foo": "bar"}'),
+        (Var.create(ATestState.value, _var_is_string=True), "a_test_state.value"),
+        (
+            Var.create(f"{ATestState.value} string", _var_is_string=True),
+            "`${a_test_state.value} string`",
+        ),
+        (Var.create(ATestState.dict_val), "a_test_state.dict_val"),
+    ],
+)
+def test_var_name_unwrapped(var, expected):
+    assert var._var_name_unwrapped == expected

+ 37 - 1
tests/utils/test_format.py

@@ -1,5 +1,5 @@
 import datetime
-from typing import Any
+from typing import Any, List
 
 import pytest
 
@@ -294,6 +294,42 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected
     assert format.format_cond(condition, true_value, false_value) == expected
 
 
+@pytest.mark.parametrize(
+    "condition, match_cases, default,expected",
+    [
+        (
+            "state__state.value",
+            [
+                [Var.create(1), Var.create("red", _var_is_string=True)],
+                [Var.create(2), Var.create(3), Var.create("blue", _var_is_string=True)],
+                [TestState.mapping, TestState.num1],
+                [
+                    Var.create(f"{TestState.map_key}-key", _var_is_string=True),
+                    Var.create("return-key", _var_is_string=True),
+                ],
+            ],
+            Var.create("yellow", _var_is_string=True),
+            "(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1):  return (`red`);  break;case JSON.stringify(2): case JSON.stringify(3):  "
+            "return (`blue`);  break;case JSON.stringify(test_state.mapping):  return "
+            "(test_state.num1);  break;case JSON.stringify(`${test_state.map_key}-key`):  return (`return-key`);"
+            "  break;default:  return (`yellow`);  break;};})()",
+        )
+    ],
+)
+def test_format_match(
+    condition: str, match_cases: List[BaseVar], default: BaseVar, expected: str
+):
+    """Test formatting a match statement.
+
+    Args:
+        condition: The condition to match.
+        match_cases: List of match cases to be matched.
+        default: Catchall case for the match statement.
+        expected: The expected string output.
+    """
+    assert format.format_match(condition, match_cases, default) == expected
+
+
 @pytest.mark.parametrize(
     "prop,formatted",
     [