match.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. """rx.match."""
  2. from typing import Any, Union, cast
  3. from typing_extensions import Unpack
  4. from reflex.components.base import Fragment
  5. from reflex.components.component import BaseComponent, Component
  6. from reflex.utils import types
  7. from reflex.utils.exceptions import MatchTypeError
  8. from reflex.vars.base import VAR_TYPE, Var
  9. from reflex.vars.number import MatchOperation
# Shape of a single match case: a tuple whose leading elements (of any type)
# are the values to match against and whose final element is the value to
# return, given either as a Var or as a plain value.
CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE]
  11. def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
  12. """Process the individual match cases.
  13. Args:
  14. cases: The match cases.
  15. Raises:
  16. ValueError: If the default case is not the last case or the tuple elements are less than 2.
  17. """
  18. for case in cases:
  19. if not isinstance(case, tuple):
  20. raise ValueError(
  21. "rx.match should have tuples of cases and a default case as the last argument."
  22. )
  23. # There should be at least two elements in a case tuple(a condition and return value)
  24. if len(case) < 2:
  25. raise ValueError(
  26. "A case tuple should have at least a match case element and a return value."
  27. )
  28. def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None:
  29. """Validate that match cases have the same return types.
  30. Args:
  31. match_cases: The match cases.
  32. Raises:
  33. MatchTypeError: If the return types of cases are different.
  34. """
  35. def is_component_or_component_var(obj: Any) -> bool:
  36. return types._isinstance(obj, BaseComponent) or (
  37. isinstance(obj, Var)
  38. and types.safe_typehint_issubclass(
  39. obj._var_type, Union[list[BaseComponent], BaseComponent]
  40. )
  41. )
  42. def type_of_return_type(obj: Any) -> Any:
  43. if isinstance(obj, Var):
  44. return obj._var_type
  45. return type(obj)
  46. return_types = [case[-1] for case in match_cases]
  47. if any(
  48. is_component_or_component_var(return_type) for return_type in return_types
  49. ) and not all(
  50. is_component_or_component_var(return_type) for return_type in return_types
  51. ):
  52. non_component_return_types = [
  53. (type_of_return_type(return_type), i)
  54. for i, return_type in enumerate(return_types)
  55. if not is_component_or_component_var(return_type)
  56. ]
  57. raise MatchTypeError(
  58. "Match cases should have the same return types. "
  59. + "Expected return types to be of type Component or Var[Component]. "
  60. + ". ".join(
  61. [
  62. f"Return type of case {i} is {return_type}"
  63. for return_type, i in non_component_return_types
  64. ]
  65. )
  66. )
  67. def _create_match_var(
  68. match_cond_var: Var,
  69. match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
  70. default: VAR_TYPE | Var[VAR_TYPE],
  71. ) -> Var[VAR_TYPE]:
  72. """Create the match var.
  73. Args:
  74. match_cond_var: The match condition var.
  75. match_cases: The match cases.
  76. default: The default case.
  77. Returns:
  78. The match var.
  79. """
  80. return MatchOperation.create(match_cond_var, match_cases, default)
  81. def match(
  82. cond: Any,
  83. *cases: Unpack[
  84. tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE]
  85. ],
  86. ) -> Var[VAR_TYPE]:
  87. """Create a match var.
  88. Args:
  89. cond: The condition to match.
  90. cases: The match cases. Each case should be a tuple with the first elements as the match case and the last element as the return value. The last argument should be the default case.
  91. Returns:
  92. The match var.
  93. Raises:
  94. ValueError: If the default case is not the last case or the tuple elements are less than 2.
  95. """
  96. default = None
  97. if len([case for case in cases if not isinstance(case, tuple)]) > 1:
  98. raise ValueError("rx.match can only have one default case.")
  99. if not cases:
  100. raise ValueError("rx.match should have at least one case.")
  101. # Get the default case which should be the last non-tuple arg
  102. if not isinstance(cases[-1], tuple):
  103. default = cases[-1]
  104. actual_cases = cases[:-1]
  105. else:
  106. actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
  107. _process_match_cases(actual_cases)
  108. _validate_return_types(actual_cases)
  109. if default is None and any(
  110. not (
  111. isinstance((return_type := case[-1]), Component)
  112. or (
  113. isinstance(return_type, Var)
  114. and types.typehint_issubclass(return_type._var_type, Component)
  115. )
  116. )
  117. for case in actual_cases
  118. ):
  119. raise ValueError(
  120. "For cases with return types as Vars, a default case must be provided"
  121. )
  122. elif default is None:
  123. default = Fragment.create()
  124. default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
  125. return _create_match_var(
  126. cond,
  127. actual_cases,
  128. default,
  129. )