|
@@ -1,6 +1,5 @@
|
|
|
"""rx.match."""
|
|
|
|
|
|
-import textwrap
|
|
|
from typing import Any, cast
|
|
|
|
|
|
from typing_extensions import Unpack
|
|
@@ -46,27 +45,43 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
|
|
|
Raises:
|
|
|
MatchTypeError: If the return types of cases are different.
|
|
|
"""
|
|
|
- first_case_return = match_cases[0][-1]
|
|
|
- return_type = type(first_case_return)
|
|
|
|
|
|
- if types._isinstance(first_case_return, BaseComponent):
|
|
|
- return_type = BaseComponent
|
|
|
- elif types._isinstance(first_case_return, Var):
|
|
|
- return_type = Var
|
|
|
-
|
|
|
- for index, case in enumerate(match_cases):
|
|
|
- if not (
|
|
|
- types._issubclass(type(case[-1]), return_type)
|
|
|
- or (
|
|
|
- isinstance(case[-1], Var)
|
|
|
- and types.typehint_issubclass(case[-1]._var_type, return_type)
|
|
|
+ def is_component_or_component_var(obj: Any) -> bool:
|
|
|
+ return types._isinstance(obj, BaseComponent) or (
|
|
|
+ isinstance(obj, Var)
|
|
|
+ and (
|
|
|
+ types.safe_typehint_issubclass(obj._var_type, BaseComponent)
|
|
|
+ or types.safe_typehint_issubclass(obj._var_type, list[BaseComponent])
|
|
|
)
|
|
|
- ):
|
|
|
- raise MatchTypeError(
|
|
|
- f"Match cases should have the same return types. Case {index} with return "
|
|
|
- f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
|
|
|
- f" of type {(type(case[-1]) if not isinstance(case[-1], Var) else case[-1]._var_type)!r} is not {return_type}"
|
|
|
+ )
|
|
|
+
|
|
|
+ def type_of_return_type(obj: Any) -> Any:
|
|
|
+ if isinstance(obj, Var):
|
|
|
+ return obj._var_type
|
|
|
+ return type(obj)
|
|
|
+
|
|
|
+ return_types = [case[-1] for case in match_cases]
|
|
|
+
|
|
|
+ if any(
|
|
|
+ is_component_or_component_var(return_type) for return_type in return_types
|
|
|
+ ) and not all(
|
|
|
+ is_component_or_component_var(return_type) for return_type in return_types
|
|
|
+ ):
|
|
|
+ non_component_return_types = [
|
|
|
+ (type_of_return_type(return_type), i)
|
|
|
+ for i, return_type in enumerate(return_types)
|
|
|
+ if not is_component_or_component_var(return_type)
|
|
|
+ ]
|
|
|
+ raise MatchTypeError(
|
|
|
+ "Match cases should have the same return types. "
|
|
|
+ + "Expected return types to be of type Component or Var[Component]. "
|
|
|
+ + ". ".join(
|
|
|
+ [
|
|
|
+ f"Return type of case {i} is {return_type}"
|
|
|
+ for return_type, i in non_component_return_types
|
|
|
+ ]
|
|
|
)
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def _create_match_var(
|