style.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. """Handle styling."""
  2. from __future__ import annotations
  3. from typing import Any, Literal, Tuple, Type
  4. from reflex import constants
  5. from reflex.components.core.breakpoints import Breakpoints, breakpoints_values
  6. from reflex.event import EventChain, EventHandler
  7. from reflex.utils import format
  8. from reflex.utils.exceptions import ReflexError
  9. from reflex.utils.imports import ImportVar
  10. from reflex.utils.types import get_origin
  11. from reflex.vars import VarData
  12. from reflex.vars.base import LiteralVar, Var
  13. from reflex.vars.function import FunctionVar
  14. from reflex.vars.object import ObjectVar
  15. SYSTEM_COLOR_MODE: str = "system"
  16. LIGHT_COLOR_MODE: str = "light"
  17. DARK_COLOR_MODE: str = "dark"
  18. LiteralColorMode = Literal["system", "light", "dark"]
  19. # Reference the global ColorModeContext
  20. color_mode_imports = {
  21. f"$/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
  22. "react": [ImportVar(tag="useContext")],
  23. }
  24. def _color_mode_var(_js_expr: str, _var_type: Type = str) -> Var:
  25. """Create a Var that destructs the _js_expr from ColorModeContext.
  26. Args:
  27. _js_expr: The name of the variable to get from ColorModeContext.
  28. _var_type: The type of the Var.
  29. Returns:
  30. The Var that resolves to the color mode.
  31. """
  32. return Var(
  33. _js_expr=_js_expr,
  34. _var_type=_var_type,
  35. _var_data=VarData(
  36. imports=color_mode_imports,
  37. hooks={f"const {{ {_js_expr} }} = useContext(ColorModeContext)": None},
  38. ),
  39. ).guess_type()
  40. def set_color_mode(
  41. new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None,
  42. ) -> Var[EventChain]:
  43. """Create an EventChain Var that sets the color mode to a specific value.
  44. Note: `set_color_mode` is not a real event and cannot be triggered from a
  45. backend event handler.
  46. Args:
  47. new_color_mode: The color mode to set.
  48. Returns:
  49. The EventChain Var that can be passed to an event trigger.
  50. """
  51. base_setter = _color_mode_var(
  52. _js_expr=constants.ColorMode.SET,
  53. _var_type=EventChain,
  54. )
  55. if new_color_mode is None:
  56. return base_setter
  57. if not isinstance(new_color_mode, Var):
  58. new_color_mode = LiteralVar.create(new_color_mode)
  59. return Var(
  60. f"() => {base_setter!s}({new_color_mode!s})",
  61. _var_data=VarData.merge(
  62. base_setter._get_all_var_data(), new_color_mode._get_all_var_data()
  63. ),
  64. ).to(FunctionVar, EventChain)
  65. # Var resolves to the current color mode for the app ("light", "dark" or "system")
  66. color_mode = _color_mode_var(_js_expr=constants.ColorMode.NAME)
  67. # Var resolves to the resolved color mode for the app ("light" or "dark")
  68. resolved_color_mode = _color_mode_var(_js_expr=constants.ColorMode.RESOLVED_NAME)
  69. # Var resolves to a function invocation that toggles the color mode
  70. toggle_color_mode = _color_mode_var(
  71. _js_expr=constants.ColorMode.TOGGLE,
  72. _var_type=EventChain,
  73. )
  74. STYLE_PROP_SHORTHAND_MAPPING = {
  75. "paddingX": ("paddingInlineStart", "paddingInlineEnd"),
  76. "paddingY": ("paddingTop", "paddingBottom"),
  77. "marginX": ("marginInlineStart", "marginInlineEnd"),
  78. "marginY": ("marginTop", "marginBottom"),
  79. "bg": ("background",),
  80. "bgColor": ("backgroundColor",),
  81. # Radix components derive their font from this CSS var, not inherited from body or class.
  82. "fontFamily": ("fontFamily", "--default-font-family"),
  83. }
  84. def media_query(breakpoint_expr: str):
  85. """Create a media query selector.
  86. Args:
  87. breakpoint_expr: The CSS expression representing the breakpoint.
  88. Returns:
  89. The media query selector used as a key in emotion css dict.
  90. """
  91. return f"@media screen and (min-width: {breakpoint_expr})"
  92. def convert_item(
  93. style_item: int | str | Var,
  94. ) -> tuple[str | Var, VarData | None]:
  95. """Format a single value in a style dictionary.
  96. Args:
  97. style_item: The style item to format.
  98. Returns:
  99. The formatted style item and any associated VarData.
  100. Raises:
  101. ReflexError: If an EventHandler is used as a style value
  102. """
  103. if isinstance(style_item, EventHandler):
  104. raise ReflexError(
  105. "EventHandlers cannot be used as style values. "
  106. "Please use a Var or a literal value."
  107. )
  108. if isinstance(style_item, Var):
  109. return style_item, style_item._get_all_var_data()
  110. # Otherwise, convert to Var to collapse VarData encoded in f-string.
  111. new_var = LiteralVar.create(style_item)
  112. var_data = new_var._get_all_var_data() if new_var is not None else None
  113. return new_var, var_data
  114. def convert_list(
  115. responsive_list: list[str | dict | Var],
  116. ) -> tuple[list[str | dict[str, Var | list | dict]], VarData | None]:
  117. """Format a responsive value list.
  118. Args:
  119. responsive_list: The raw responsive value list (one value per breakpoint).
  120. Returns:
  121. The recursively converted responsive value list and any associated VarData.
  122. """
  123. converted_value = []
  124. item_var_datas = []
  125. for responsive_item in responsive_list:
  126. if isinstance(responsive_item, dict):
  127. # Recursively format nested style dictionaries.
  128. item, item_var_data = convert(responsive_item)
  129. else:
  130. item, item_var_data = convert_item(responsive_item)
  131. converted_value.append(item)
  132. item_var_datas.append(item_var_data)
  133. return converted_value, VarData.merge(*item_var_datas)
  134. def convert(
  135. style_dict: dict[str, Var | dict | list | str],
  136. ) -> tuple[dict[str, str | list | dict], VarData | None]:
  137. """Format a style dictionary.
  138. Args:
  139. style_dict: The style dictionary to format.
  140. Returns:
  141. The formatted style dictionary.
  142. """
  143. var_data = None # Track import/hook data from any Vars in the style dict.
  144. out = {}
  145. def update_out_dict(
  146. return_value: Var | dict | list | str, keys_to_update: tuple[str, ...]
  147. ):
  148. for k in keys_to_update:
  149. out[k] = return_value
  150. for key, value in style_dict.items():
  151. keys = (
  152. format_style_key(key)
  153. if not isinstance(value, (dict, ObjectVar))
  154. or (
  155. isinstance(value, Breakpoints)
  156. and all(not isinstance(v, dict) for v in value.values())
  157. )
  158. or (
  159. isinstance(value, ObjectVar)
  160. and not issubclass(get_origin(value._var_type) or value._var_type, dict)
  161. )
  162. else (key,)
  163. )
  164. if isinstance(value, Var):
  165. return_val = value
  166. new_var_data = value._get_all_var_data()
  167. update_out_dict(return_val, keys)
  168. elif isinstance(value, dict):
  169. # Recursively format nested style dictionaries.
  170. return_val, new_var_data = convert(value)
  171. update_out_dict(return_val, keys)
  172. elif isinstance(value, list):
  173. # Responsive value is a list of dict or value
  174. return_val, new_var_data = convert_list(value)
  175. update_out_dict(return_val, keys)
  176. else:
  177. return_val, new_var_data = convert_item(value)
  178. update_out_dict(return_val, keys)
  179. # Combine all the collected VarData instances.
  180. var_data = VarData.merge(var_data, new_var_data)
  181. if isinstance(style_dict, Breakpoints):
  182. out = Breakpoints(out).factorize()
  183. return out, var_data
  184. def format_style_key(key: str) -> Tuple[str, ...]:
  185. """Convert style keys to camel case and convert shorthand
  186. styles names to their corresponding css names.
  187. Args:
  188. key: The style key to convert.
  189. Returns:
  190. Tuple of css style names corresponding to the key provided.
  191. """
  192. key = format.to_camel_case(key, allow_hyphens=True)
  193. return STYLE_PROP_SHORTHAND_MAPPING.get(key, (key,))
  194. class Style(dict):
  195. """A style dictionary."""
  196. def __init__(self, style_dict: dict | None = None, **kwargs):
  197. """Initialize the style.
  198. Args:
  199. style_dict: The style dictionary.
  200. kwargs: Other key value pairs to apply to the dict update.
  201. """
  202. if style_dict:
  203. style_dict.update(kwargs)
  204. else:
  205. style_dict = kwargs
  206. style_dict, self._var_data = convert(style_dict or {})
  207. super().__init__(style_dict)
  208. def update(self, style_dict: dict | None, **kwargs):
  209. """Update the style.
  210. Args:
  211. style_dict: The style dictionary.
  212. kwargs: Other key value pairs to apply to the dict update.
  213. """
  214. if not isinstance(style_dict, Style):
  215. converted_dict = type(self)(style_dict)
  216. else:
  217. converted_dict = style_dict
  218. if kwargs:
  219. if converted_dict is None:
  220. converted_dict = type(self)(kwargs)
  221. else:
  222. converted_dict.update(kwargs)
  223. # Combine our VarData with that of any Vars in the style_dict that was passed.
  224. self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
  225. super().update(converted_dict)
  226. def __setitem__(self, key: str, value: Any):
  227. """Set an item in the style.
  228. Args:
  229. key: The key to set.
  230. value: The value to set.
  231. """
  232. # Create a Var to collapse VarData encoded in f-string.
  233. _var = LiteralVar.create(value)
  234. if _var is not None:
  235. # Carry the imports/hooks when setting a Var as a value.
  236. self._var_data = VarData.merge(
  237. getattr(self, "_var_data", None), _var._get_all_var_data()
  238. )
  239. super().__setitem__(key, value)
  240. def __or__(self, other: Style | dict) -> Style:
  241. """Combine two styles.
  242. Args:
  243. other: The other style to combine.
  244. Returns:
  245. The combined style.
  246. """
  247. other_var_data = None
  248. if not isinstance(other, Style):
  249. other_dict, other_var_data = convert(other)
  250. else:
  251. other_dict, other_var_data = other, other._var_data
  252. new_style = Style(super().__or__(other_dict))
  253. if self._var_data or other_var_data:
  254. new_style._var_data = VarData.merge(self._var_data, other_var_data)
  255. return new_style
  256. def _format_emotion_style_pseudo_selector(key: str) -> str:
  257. """Format a pseudo selector for emotion CSS-in-JS.
  258. Args:
  259. key: Underscore-prefixed or colon-prefixed pseudo selector key (_hover/:hover).
  260. Returns:
  261. A self-referential pseudo selector key (&:hover).
  262. """
  263. prefix = None
  264. if key.startswith("_"):
  265. prefix = "&:"
  266. key = key[1:]
  267. if key.startswith(":"):
  268. # Handle pseudo selectors and elements in native format.
  269. prefix = "&"
  270. if prefix is not None:
  271. return prefix + format.to_kebab_case(key)
  272. return key
  273. def format_as_emotion(style_dict: dict[str, Any]) -> Style | None:
  274. """Convert the style to an emotion-compatible CSS-in-JS dict.
  275. Args:
  276. style_dict: The style dict to convert.
  277. Returns:
  278. The emotion style dict.
  279. """
  280. _var_data = style_dict._var_data if isinstance(style_dict, Style) else None
  281. emotion_style = Style()
  282. for orig_key, value in style_dict.items():
  283. key = _format_emotion_style_pseudo_selector(orig_key)
  284. if isinstance(value, (Breakpoints, list)):
  285. if isinstance(value, Breakpoints):
  286. mbps = {
  287. media_query(bp): (
  288. bp_value if isinstance(bp_value, dict) else {key: bp_value}
  289. )
  290. for bp, bp_value in value.items()
  291. }
  292. else:
  293. # Apply media queries from responsive value list.
  294. mbps = {
  295. media_query([0, *breakpoints_values][bp]): (
  296. bp_value if isinstance(bp_value, dict) else {key: bp_value}
  297. )
  298. for bp, bp_value in enumerate(value)
  299. }
  300. if key.startswith("&:"):
  301. emotion_style[key] = mbps
  302. else:
  303. for mq, style_sub_dict in mbps.items():
  304. emotion_style.setdefault(mq, {}).update(style_sub_dict)
  305. elif isinstance(value, dict):
  306. # Recursively format nested style dictionaries.
  307. emotion_style[key] = format_as_emotion(value)
  308. else:
  309. emotion_style[key] = value
  310. if emotion_style:
  311. if _var_data is not None:
  312. emotion_style._var_data = VarData.merge(emotion_style._var_data, _var_data)
  313. return emotion_style
  314. def convert_dict_to_style_and_format_emotion(
  315. raw_dict: dict[str, Any],
  316. ) -> dict[str, Any] | None:
  317. """Convert a dict to a style dict and then format as emotion.
  318. Args:
  319. raw_dict: The dict to convert.
  320. Returns:
  321. The emotion dict.
  322. """
  323. return format_as_emotion(Style(raw_dict))
  324. STACK_CHILDREN_FULL_WIDTH = {
  325. "& :where(.rx-Stack)": {
  326. "width": "100%",
  327. },
  328. "& :where(.rx-Stack) > :where( "
  329. "div:not(.rt-Box, .rx-Upload, .rx-Html),"
  330. "input, select, textarea, table"
  331. ")": {
  332. "width": "100%",
  333. "flex_shrink": "1",
  334. },
  335. }