match.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. """rx.match."""
  2. import textwrap
  3. from typing import Any, Dict, List, Optional, Tuple, Union
  4. from reflex.components.base import Fragment
  5. from reflex.components.component import BaseComponent, Component, MemoizationLeaf
  6. from reflex.components.core.colors import Color
  7. from reflex.components.tags import MatchTag, Tag
  8. from reflex.style import Style
  9. from reflex.utils import format, imports, types
  10. from reflex.utils.exceptions import MatchTypeError
  11. from reflex.vars import BaseVar, Var, VarData
  12. class Match(MemoizationLeaf):
  13. """Match cases based on a condition."""
  14. # The condition to determine which case to match.
  15. cond: Var[Any]
  16. # The list of match cases to be matched.
  17. match_cases: List[Any] = []
  18. # The catchall case to match.
  19. default: Any
  20. @classmethod
  21. def create(cls, cond: Any, *cases) -> Union[Component, BaseVar]:
  22. """Create a Match Component.
  23. Args:
  24. cond: The condition to determine which case to match.
  25. cases: This list of cases to match.
  26. Returns:
  27. The match component.
  28. Raises:
  29. ValueError: When a default case is not provided for cases with Var return types.
  30. """
  31. match_cond_var = cls._create_condition_var(cond)
  32. cases, default = cls._process_cases(list(cases))
  33. match_cases = cls._process_match_cases(cases)
  34. cls._validate_return_types(match_cases)
  35. if default is None and types._issubclass(type(match_cases[0][-1]), BaseVar):
  36. raise ValueError(
  37. "For cases with return types as Vars, a default case must be provided"
  38. )
  39. return cls._create_match_cond_var_or_component(
  40. match_cond_var, match_cases, default
  41. )
  42. @classmethod
  43. def _create_condition_var(cls, cond: Any) -> BaseVar:
  44. """Convert the condition to a Var.
  45. Args:
  46. cond: The condition.
  47. Returns:
  48. The condition as a base var
  49. Raises:
  50. ValueError: If the condition is not provided.
  51. """
  52. match_cond_var = Var.create(cond, _var_is_string=isinstance(cond, str))
  53. if match_cond_var is None:
  54. raise ValueError("The condition must be set")
  55. return match_cond_var # type: ignore
  56. @classmethod
  57. def _process_cases(
  58. cls, cases: List
  59. ) -> Tuple[List, Optional[Union[BaseVar, BaseComponent]]]:
  60. """Process the list of match cases and the catchall default case.
  61. Args:
  62. cases: The list of match cases.
  63. Returns:
  64. The default case and the list of match case tuples.
  65. Raises:
  66. ValueError: If there are multiple default cases.
  67. """
  68. default = None
  69. if len([case for case in cases if not isinstance(case, tuple)]) > 1:
  70. raise ValueError("rx.match can only have one default case.")
  71. # Get the default case which should be the last non-tuple arg
  72. if not isinstance(cases[-1], tuple):
  73. default = cases.pop()
  74. default = (
  75. cls._create_case_var_with_var_data(default)
  76. if not isinstance(default, BaseComponent)
  77. else default
  78. )
  79. return cases, default # type: ignore
  80. @classmethod
  81. def _create_case_var_with_var_data(cls, case_element):
  82. """Convert a case element into a Var.If the case
  83. is a Style type, we extract the var data and merge it with the
  84. newly created Var.
  85. Args:
  86. case_element: The case element.
  87. Returns:
  88. The case element Var.
  89. """
  90. _var_data = case_element._var_data if isinstance(case_element, Style) else None # type: ignore
  91. case_element = Var.create(
  92. case_element,
  93. _var_is_string=isinstance(case_element, (str, Color)),
  94. )
  95. if _var_data is not None:
  96. case_element._var_data = VarData.merge(case_element._var_data, _var_data) # type: ignore
  97. return case_element
  98. @classmethod
  99. def _process_match_cases(cls, cases: List) -> List[List[BaseVar]]:
  100. """Process the individual match cases.
  101. Args:
  102. cases: The match cases.
  103. Returns:
  104. The processed match cases.
  105. Raises:
  106. ValueError: If the default case is not the last case or the tuple elements are less than 2.
  107. """
  108. match_cases = []
  109. for case in cases:
  110. if not isinstance(case, tuple):
  111. raise ValueError(
  112. "rx.match should have tuples of cases and a default case as the last argument."
  113. )
  114. # There should be at least two elements in a case tuple(a condition and return value)
  115. if len(case) < 2:
  116. raise ValueError(
  117. "A case tuple should have at least a match case element and a return value."
  118. )
  119. case_list = []
  120. for element in case:
  121. # convert all non component element to vars.
  122. el = (
  123. cls._create_case_var_with_var_data(element)
  124. if not isinstance(element, BaseComponent)
  125. else element
  126. )
  127. if not isinstance(el, (BaseVar, BaseComponent)):
  128. raise ValueError("Case element must be a var or component")
  129. case_list.append(el)
  130. match_cases.append(case_list)
  131. return match_cases
  132. @classmethod
  133. def _validate_return_types(cls, match_cases: List[List[BaseVar]]) -> None:
  134. """Validate that match cases have the same return types.
  135. Args:
  136. match_cases: The match cases.
  137. Raises:
  138. MatchTypeError: If the return types of cases are different.
  139. """
  140. first_case_return = match_cases[0][-1]
  141. return_type = type(first_case_return)
  142. if types._isinstance(first_case_return, BaseComponent):
  143. return_type = BaseComponent
  144. elif types._isinstance(first_case_return, BaseVar):
  145. return_type = BaseVar
  146. for index, case in enumerate(match_cases):
  147. if not types._issubclass(type(case[-1]), return_type):
  148. raise MatchTypeError(
  149. f"Match cases should have the same return types. Case {index} with return "
  150. f"value `{case[-1]._var_name if isinstance(case[-1], BaseVar) else textwrap.shorten(str(case[-1]), width=250)}`"
  151. f" of type {type(case[-1])!r} is not {return_type}"
  152. )
  153. @classmethod
  154. def _create_match_cond_var_or_component(
  155. cls,
  156. match_cond_var: Var,
  157. match_cases: List[List[BaseVar]],
  158. default: Optional[Union[BaseVar, BaseComponent]],
  159. ) -> Union[Component, BaseVar]:
  160. """Create and return the match condition var or component.
  161. Args:
  162. match_cond_var: The match condition.
  163. match_cases: The list of match cases.
  164. default: The default case.
  165. Returns:
  166. The match component wrapped in a fragment or the match var.
  167. Raises:
  168. ValueError: If the return types are not vars when creating a match var for Var types.
  169. """
  170. if default is None and types._issubclass(
  171. type(match_cases[0][-1]), BaseComponent
  172. ):
  173. default = Fragment.create()
  174. if types._issubclass(type(match_cases[0][-1]), BaseComponent):
  175. return Fragment.create(
  176. cls(
  177. cond=match_cond_var,
  178. match_cases=match_cases,
  179. default=default,
  180. children=[case[-1] for case in match_cases] + [default], # type: ignore
  181. )
  182. )
  183. # Validate the match cases (as well as the default case) to have Var return types.
  184. if any(
  185. case for case in match_cases if not types._isinstance(case[-1], BaseVar)
  186. ) or not types._isinstance(default, BaseVar):
  187. raise ValueError("Return types of match cases should be Vars.")
  188. # match cases and default should all be Vars at this point.
  189. # Retrieve var data of every var in the match cases and default.
  190. var_data = [
  191. *[el._var_data for case in match_cases for el in case],
  192. default._var_data, # type: ignore
  193. ]
  194. return match_cond_var._replace(
  195. _var_name=format.format_match(
  196. cond=match_cond_var._var_name_unwrapped,
  197. match_cases=match_cases, # type: ignore
  198. default=default, # type: ignore
  199. ),
  200. _var_type=default._var_type, # type: ignore
  201. _var_is_local=False,
  202. _var_full_name_needs_state_prefix=False,
  203. _var_is_string=False,
  204. merge_var_data=VarData.merge(*var_data),
  205. )
  206. def _render(self) -> Tag:
  207. return MatchTag(
  208. cond=self.cond, match_cases=self.match_cases, default=self.default
  209. )
  210. def render(self) -> Dict:
  211. """Render the component.
  212. Returns:
  213. The dictionary for template of component.
  214. """
  215. tag = self._render()
  216. tag.name = "match"
  217. return dict(tag)
  218. def _get_imports_list(self) -> list[imports.ImportVar]:
  219. return [
  220. *super()._get_imports_list(),
  221. *getattr(self.cond._var_data, "imports", []),
  222. ]
  223. def _apply_theme(self, theme: Component):
  224. """Apply the theme to this component.
  225. Args:
  226. theme: The theme to apply.
  227. """
  228. # apply theme to return components.
  229. for match_case in self.match_cases:
  230. if isinstance(match_case[-1], Component):
  231. match_case[-1].apply_theme(theme)
  232. # apply theme to default component
  233. if isinstance(self.default, Component):
  234. self.default.apply_theme(theme)