ソースを参照

refactor match code

Khaleel Al-Adhami 3 ヶ月 前
コミット
c74313992f
1 ファイル変更24 行追加26 行削除
  1. 24 26
      reflex/components/core/match.py

+ 24 - 26
reflex/components/core/match.py

@@ -5,7 +5,7 @@ from typing import Any, Union, cast
 from typing_extensions import Unpack
 from typing_extensions import Unpack
 
 
 from reflex.components.base import Fragment
 from reflex.components.base import Fragment
-from reflex.components.component import BaseComponent, Component
+from reflex.components.component import BaseComponent
 from reflex.utils import types
 from reflex.utils import types
 from reflex.utils.exceptions import MatchTypeError
 from reflex.utils.exceptions import MatchTypeError
 from reflex.vars.base import VAR_TYPE, Var
 from reflex.vars.base import VAR_TYPE, Var
@@ -36,11 +36,14 @@ def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
             )
             )
 
 
 
 
-def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None:
+def _validate_return_types(*return_values: Any) -> bool:
     """Validate that match cases have the same return types.
     """Validate that match cases have the same return types.
 
 
     Args:
     Args:
-        match_cases: The match cases.
+        return_values: The return values of the match cases.
+
+    Returns:
+        True if all cases have the same return types.
 
 
     Raises:
     Raises:
         MatchTypeError: If the return types of cases are different.
         MatchTypeError: If the return types of cases are different.
@@ -54,22 +57,20 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
             )
             )
         )
         )
 
 
-    def type_of_return_type(obj: Any) -> Any:
+    def type_of_return_value(obj: Any) -> Any:
         if isinstance(obj, Var):
         if isinstance(obj, Var):
             return obj._var_type
             return obj._var_type
         return type(obj)
         return type(obj)
 
 
-    return_types = [case[-1] for case in match_cases]
+    is_return_type_component = [
+        is_component_or_component_var(return_type) for return_type in return_values
+    ]
 
 
-    if any(
-        is_component_or_component_var(return_type) for return_type in return_types
-    ) and not all(
-        is_component_or_component_var(return_type) for return_type in return_types
-    ):
+    if any(is_return_type_component) and not all(is_return_type_component):
         non_component_return_types = [
         non_component_return_types = [
-            (type_of_return_type(return_type), i)
-            for i, return_type in enumerate(return_types)
-            if not is_component_or_component_var(return_type)
+            (type_of_return_value(return_value), i)
+            for i, return_value in enumerate(return_values)
+            if not is_return_type_component[i]
         ]
         ]
         raise MatchTypeError(
         raise MatchTypeError(
             "Match cases should have the same return types. "
             "Match cases should have the same return types. "
@@ -82,6 +83,8 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
             )
             )
         )
         )
 
 
+    return all(is_return_type_component)
+
 
 
 def _create_match_var(
 def _create_match_var(
     match_cond_var: Var,
     match_cond_var: Var,
@@ -119,7 +122,7 @@ def match(
     Raises:
     Raises:
         ValueError: If the default case is not the last case or the tuple elements are less than 2.
         ValueError: If the default case is not the last case or the tuple elements are less than 2.
     """
     """
-    default = None
+    default = types.Unset()
 
 
     if len([case for case in cases if not isinstance(case, tuple)]) > 1:
     if len([case for case in cases if not isinstance(case, tuple)]) > 1:
         raise ValueError("rx.match can only have one default case.")
         raise ValueError("rx.match can only have one default case.")
@@ -136,22 +139,17 @@ def match(
 
 
     _process_match_cases(actual_cases)
     _process_match_cases(actual_cases)
 
 
-    _validate_return_types(actual_cases)
+    is_component_match = _validate_return_types(
+        *[case[-1] for case in actual_cases],
+        *([default] if not isinstance(default, types.Unset) else []),
+    )
 
 
-    if default is None and any(
-        not (
-            isinstance((return_type := case[-1]), Component)
-            or (
-                isinstance(return_type, Var)
-                and types.typehint_issubclass(return_type._var_type, Component)
-            )
-        )
-        for case in actual_cases
-    ):
+    if isinstance(default, types.Unset) and not is_component_match:
         raise ValueError(
         raise ValueError(
             "For cases with return types as Vars, a default case must be provided"
             "For cases with return types as Vars, a default case must be provided"
         )
         )
-    elif default is None:
+
+    if isinstance(default, types.Unset):
         default = Fragment.create()
         default = Fragment.create()
 
 
     default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
     default = cast(Var[VAR_TYPE] | VAR_TYPE, default)