match.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. """rx.match."""
  2. from typing import Any, Union, cast
  3. from typing_extensions import Unpack
  4. from reflex.components.base import Fragment
  5. from reflex.components.component import BaseComponent
  6. from reflex.utils import types
  7. from reflex.utils.exceptions import MatchTypeError
  8. from reflex.vars.base import VAR_TYPE, Var
  9. from reflex.vars.number import MatchOperation
  10. CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE]
  11. def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
  12. """Process the individual match cases.
  13. Args:
  14. cases: The match cases.
  15. Raises:
  16. ValueError: If the default case is not the last case or the tuple elements are less than 2.
  17. """
  18. for case in cases:
  19. if not isinstance(case, tuple):
  20. raise ValueError(
  21. "rx.match should have tuples of cases and a default case as the last argument."
  22. )
  23. # There should be at least two elements in a case tuple(a condition and return value)
  24. if len(case) < 2:
  25. raise ValueError(
  26. "A case tuple should have at least a match case element and a return value."
  27. )
  28. def _validate_return_types(*return_values: Any) -> bool:
  29. """Validate that match cases have the same return types.
  30. Args:
  31. return_values: The return values of the match cases.
  32. Returns:
  33. True if all cases have the same return types.
  34. Raises:
  35. MatchTypeError: If the return types of cases are different.
  36. """
  37. def is_component_or_component_var(obj: Any) -> bool:
  38. return types._isinstance(obj, BaseComponent) or (
  39. isinstance(obj, Var)
  40. and types.safe_typehint_issubclass(
  41. obj._var_type, Union[list[BaseComponent], BaseComponent]
  42. )
  43. )
  44. def type_of_return_value(obj: Any) -> Any:
  45. if isinstance(obj, Var):
  46. return obj._var_type
  47. return type(obj)
  48. is_return_type_component = [
  49. is_component_or_component_var(return_type) for return_type in return_values
  50. ]
  51. if any(is_return_type_component) and not all(is_return_type_component):
  52. non_component_return_types = [
  53. (type_of_return_value(return_value), i)
  54. for i, return_value in enumerate(return_values)
  55. if not is_return_type_component[i]
  56. ]
  57. raise MatchTypeError(
  58. "Match cases should have the same return types. "
  59. + "Expected return types to be of type Component or Var[Component]. "
  60. + ". ".join(
  61. [
  62. f"Return type of case {i} is {return_type}"
  63. for return_type, i in non_component_return_types
  64. ]
  65. )
  66. )
  67. return all(is_return_type_component)
  68. def _create_match_var(
  69. match_cond_var: Var,
  70. match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
  71. default: VAR_TYPE | Var[VAR_TYPE],
  72. ) -> Var[VAR_TYPE]:
  73. """Create the match var.
  74. Args:
  75. match_cond_var: The match condition var.
  76. match_cases: The match cases.
  77. default: The default case.
  78. Returns:
  79. The match var.
  80. """
  81. return MatchOperation.create(match_cond_var, match_cases, default)
  82. def match(
  83. cond: Any,
  84. *cases: Unpack[
  85. tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE]
  86. ],
  87. ) -> Var[VAR_TYPE]:
  88. """Create a match var.
  89. Args:
  90. cond: The condition to match.
  91. cases: The match cases. Each case should be a tuple with the first elements as the match case and the last element as the return value. The last argument should be the default case.
  92. Returns:
  93. The match var.
  94. Raises:
  95. ValueError: If the default case is not the last case or the tuple elements are less than 2.
  96. """
  97. default = types.Unset()
  98. if len([case for case in cases if not isinstance(case, tuple)]) > 1:
  99. raise ValueError("rx.match can only have one default case.")
  100. if not cases:
  101. raise ValueError("rx.match should have at least one case.")
  102. # Get the default case which should be the last non-tuple arg
  103. if not isinstance(cases[-1], tuple):
  104. default = cases[-1]
  105. actual_cases = cases[:-1]
  106. else:
  107. actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
  108. _process_match_cases(actual_cases)
  109. is_component_match = _validate_return_types(
  110. *[case[-1] for case in actual_cases],
  111. *([default] if not isinstance(default, types.Unset) else []),
  112. )
  113. if isinstance(default, types.Unset) and not is_component_match:
  114. raise ValueError(
  115. "For cases with return types as Vars, a default case must be provided"
  116. )
  117. if isinstance(default, types.Unset):
  118. default = Fragment.create()
  119. default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
  120. return _create_match_var(
  121. cond,
  122. actual_cases,
  123. default,
  124. )