Pārlūkot izejas kodu

make the match logic better

Khaleel Al-Adhami 4 mēneši atpakaļ
vecāks
revīzija
29fc4b020a

+ 34 - 19
reflex/components/core/match.py

@@ -1,6 +1,5 @@
 """rx.match."""
 
-import textwrap
 from typing import Any, cast
 
 from typing_extensions import Unpack
@@ -46,27 +45,43 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
     Raises:
         MatchTypeError: If the return types of cases are different.
     """
-    first_case_return = match_cases[0][-1]
-    return_type = type(first_case_return)
 
-    if types._isinstance(first_case_return, BaseComponent):
-        return_type = BaseComponent
-    elif types._isinstance(first_case_return, Var):
-        return_type = Var
-
-    for index, case in enumerate(match_cases):
-        if not (
-            types._issubclass(type(case[-1]), return_type)
-            or (
-                isinstance(case[-1], Var)
-                and types.typehint_issubclass(case[-1]._var_type, return_type)
+    def is_component_or_component_var(obj: Any) -> bool:
+        return types._isinstance(obj, BaseComponent) or (
+            isinstance(obj, Var)
+            and (
+                types.safe_typehint_issubclass(obj._var_type, BaseComponent)
+                or types.safe_typehint_issubclass(obj._var_type, list[BaseComponent])
             )
-        ):
-            raise MatchTypeError(
-                f"Match cases should have the same return types. Case {index} with return "
-                f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
-                f" of type {(type(case[-1]) if not isinstance(case[-1], Var) else case[-1]._var_type)!r} is not {return_type}"
+        )
+
+    def type_of_return_type(obj: Any) -> Any:
+        if isinstance(obj, Var):
+            return obj._var_type
+        return type(obj)
+
+    return_types = [case[-1] for case in match_cases]
+
+    if any(
+        is_component_or_component_var(return_type) for return_type in return_types
+    ) and not all(
+        is_component_or_component_var(return_type) for return_type in return_types
+    ):
+        non_component_return_types = [
+            (type_of_return_type(return_type), i)
+            for i, return_type in enumerate(return_types)
+            if not is_component_or_component_var(return_type)
+        ]
+        raise MatchTypeError(
+            "Match cases should have the same return types. "
+            + "Expected return types to be of type Component or Var[Component]. "
+            + ". ".join(
+                [
+                    f"Return type of case {i} is {return_type}"
+                    for return_type, i in non_component_return_types
+                ]
             )
+        )
 
 
 def _create_match_var(

+ 4 - 5
tests/units/components/core/test_match.py

@@ -1,3 +1,4 @@
+import re
 from typing import Tuple
 
 import pytest
@@ -177,8 +178,7 @@ def test_match_case_tuple_elements(match_case):
                 (MatchState.num + 1, "black"),
                 rx.text("default value"),
             ),
-            "Match cases should have the same return types. Case 3 with return value `red` of type "
-            "<class 'str'> is not <class 'reflex.components.component.BaseComponent'>",
+            "Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 3 is <class 'str'>. Return type of case 4 is <class 'str'>. Return type of case 5 is <class 'str'>",
         ),
         (
             (
@@ -190,8 +190,7 @@ def test_match_case_tuple_elements(match_case):
                 ([1, 2], rx.text("third value")),
                 rx.text("default value"),
             ),
-            'Match cases should have the same return types. Case 3 with return value `<RadixThemesText as={"p"}> {"first value"} </RadixThemesText>` '
-            "of type <class 'reflex.components.radix.themes.typography.text.Text'> is not <class 'str'>",
+            "Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 0 is <class 'str'>. Return type of case 1 is <class 'str'>. Return type of case 2 is <class 'str'>",
         ),
     ],
 )
@@ -202,7 +201,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
         cases: The match cases.
         error_msg: Expected error message.
     """
-    with pytest.raises(MatchTypeError, match=error_msg):
+    with pytest.raises(MatchTypeError, match=re.escape(error_msg)):
         match(MatchState.value, *cases)  # pyright: ignore[reportCallIssue]