style.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. """Handle styling."""
  2. from __future__ import annotations
  3. from typing import Any, Literal, Tuple, Type
  4. from reflex import constants
  5. from reflex.components.core.breakpoints import Breakpoints, breakpoints_values
  6. from reflex.event import EventChain
  7. from reflex.utils import format
  8. from reflex.utils.imports import ImportVar
  9. from reflex.vars import BaseVar, CallableVar, Var, VarData
  10. VarData.update_forward_refs() # Ensure all type definitions are resolved
  11. SYSTEM_COLOR_MODE: str = "system"
  12. LIGHT_COLOR_MODE: str = "light"
  13. DARK_COLOR_MODE: str = "dark"
  14. LiteralColorMode = Literal["system", "light", "dark"]
  15. # Reference the global ColorModeContext
  16. color_mode_imports = {
  17. f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
  18. "react": [ImportVar(tag="useContext")],
  19. }
  20. def _color_mode_var(_var_name: str, _var_type: Type = str) -> BaseVar:
  21. """Create a Var that destructs the _var_name from ColorModeContext.
  22. Args:
  23. _var_name: The name of the variable to get from ColorModeContext.
  24. _var_type: The type of the Var.
  25. Returns:
  26. The BaseVar for accessing _var_name from ColorModeContext.
  27. """
  28. return BaseVar(
  29. _var_name=_var_name,
  30. _var_type=_var_type,
  31. _var_is_local=False,
  32. _var_is_string=False,
  33. _var_data=VarData(
  34. imports=color_mode_imports,
  35. hooks={f"const {{ {_var_name} }} = useContext(ColorModeContext)": None},
  36. ),
  37. )
  38. @CallableVar
  39. def set_color_mode(
  40. new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None,
  41. ) -> BaseVar[EventChain]:
  42. """Create an EventChain Var that sets the color mode to a specific value.
  43. Note: `set_color_mode` is not a real event and cannot be triggered from a
  44. backend event handler.
  45. Args:
  46. new_color_mode: The color mode to set.
  47. Returns:
  48. The EventChain Var that can be passed to an event trigger.
  49. """
  50. base_setter = _color_mode_var(
  51. _var_name=constants.ColorMode.SET,
  52. _var_type=EventChain,
  53. )
  54. if new_color_mode is None:
  55. return base_setter
  56. if not isinstance(new_color_mode, Var):
  57. new_color_mode = Var.create_safe(new_color_mode, _var_is_string=True)
  58. return base_setter._replace(
  59. _var_name=f"() => {base_setter._var_name}({new_color_mode._var_name_unwrapped})",
  60. merge_var_data=new_color_mode._var_data,
  61. )
  62. # Var resolves to the current color mode for the app ("light", "dark" or "system")
  63. color_mode = _color_mode_var(_var_name=constants.ColorMode.NAME)
  64. # Var resolves to the resolved color mode for the app ("light" or "dark")
  65. resolved_color_mode = _color_mode_var(_var_name=constants.ColorMode.RESOLVED_NAME)
  66. # Var resolves to a function invocation that toggles the color mode
  67. toggle_color_mode = _color_mode_var(
  68. _var_name=constants.ColorMode.TOGGLE,
  69. _var_type=EventChain,
  70. )
  71. STYLE_PROP_SHORTHAND_MAPPING = {
  72. "paddingX": ("paddingInlineStart", "paddingInlineEnd"),
  73. "paddingY": ("paddingTop", "paddingBottom"),
  74. "marginX": ("marginInlineStart", "marginInlineEnd"),
  75. "marginY": ("marginTop", "marginBottom"),
  76. "bg": ("background",),
  77. "bgColor": ("backgroundColor",),
  78. # Radix components derive their font from this CSS var, not inherited from body or class.
  79. "fontFamily": ("fontFamily", "--default-font-family"),
  80. }
  81. def media_query(breakpoint_expr: str):
  82. """Create a media query selector.
  83. Args:
  84. breakpoint_expr: The CSS expression representing the breakpoint.
  85. Returns:
  86. The media query selector used as a key in emotion css dict.
  87. """
  88. return f"@media screen and (min-width: {breakpoint_expr})"
  89. def convert_item(style_item: str | Var) -> tuple[str, VarData | None]:
  90. """Format a single value in a style dictionary.
  91. Args:
  92. style_item: The style item to format.
  93. Returns:
  94. The formatted style item and any associated VarData.
  95. """
  96. if isinstance(style_item, Var):
  97. # If the value is a Var, extract the var_data and cast as str.
  98. return str(style_item), style_item._var_data
  99. # Otherwise, convert to Var to collapse VarData encoded in f-string.
  100. new_var = Var.create(style_item, _var_is_string=False)
  101. if new_var is not None and new_var._var_data:
  102. # The wrapped backtick is used to identify the Var for interpolation.
  103. return f"`{str(new_var)}`", new_var._var_data
  104. return style_item, None
  105. def convert_list(
  106. responsive_list: list[str | dict | Var],
  107. ) -> tuple[list[str | dict], VarData | None]:
  108. """Format a responsive value list.
  109. Args:
  110. responsive_list: The raw responsive value list (one value per breakpoint).
  111. Returns:
  112. The recursively converted responsive value list and any associated VarData.
  113. """
  114. converted_value = []
  115. item_var_datas = []
  116. for responsive_item in responsive_list:
  117. if isinstance(responsive_item, dict):
  118. # Recursively format nested style dictionaries.
  119. item, item_var_data = convert(responsive_item)
  120. else:
  121. item, item_var_data = convert_item(responsive_item)
  122. converted_value.append(item)
  123. item_var_datas.append(item_var_data)
  124. return converted_value, VarData.merge(*item_var_datas)
  125. def convert(style_dict):
  126. """Format a style dictionary.
  127. Args:
  128. style_dict: The style dictionary to format.
  129. Returns:
  130. The formatted style dictionary.
  131. """
  132. var_data = None # Track import/hook data from any Vars in the style dict.
  133. out = {}
  134. def update_out_dict(return_value, keys_to_update):
  135. for k in keys_to_update:
  136. out[k] = return_value
  137. for key, value in style_dict.items():
  138. keys = format_style_key(key)
  139. if isinstance(value, dict):
  140. # Recursively format nested style dictionaries.
  141. return_val, new_var_data = convert(value)
  142. update_out_dict(return_val, keys)
  143. elif isinstance(value, list):
  144. # Responsive value is a list of dict or value
  145. return_val, new_var_data = convert_list(value)
  146. update_out_dict(return_val, keys)
  147. else:
  148. return_val, new_var_data = convert_item(value)
  149. update_out_dict(return_val, keys)
  150. # Combine all the collected VarData instances.
  151. var_data = VarData.merge(var_data, new_var_data)
  152. if isinstance(style_dict, Breakpoints):
  153. out = Breakpoints(out).factorize()
  154. return out, var_data
  155. def format_style_key(key: str) -> Tuple[str, ...]:
  156. """Convert style keys to camel case and convert shorthand
  157. styles names to their corresponding css names.
  158. Args:
  159. key: The style key to convert.
  160. Returns:
  161. Tuple of css style names corresponding to the key provided.
  162. """
  163. key = format.to_camel_case(key, allow_hyphens=True)
  164. return STYLE_PROP_SHORTHAND_MAPPING.get(key, (key,))
  165. class Style(dict):
  166. """A style dictionary."""
  167. def __init__(self, style_dict: dict | None = None, **kwargs):
  168. """Initialize the style.
  169. Args:
  170. style_dict: The style dictionary.
  171. kwargs: Other key value pairs to apply to the dict update.
  172. """
  173. if style_dict:
  174. style_dict.update(kwargs)
  175. else:
  176. style_dict = kwargs
  177. style_dict, self._var_data = convert(style_dict or {})
  178. super().__init__(style_dict)
  179. def update(self, style_dict: dict | None, **kwargs):
  180. """Update the style.
  181. Args:
  182. style_dict: The style dictionary.
  183. kwargs: Other key value pairs to apply to the dict update.
  184. """
  185. if not isinstance(style_dict, Style):
  186. converted_dict = type(self)(style_dict)
  187. else:
  188. converted_dict = style_dict
  189. if kwargs:
  190. if converted_dict is None:
  191. converted_dict = type(self)(kwargs)
  192. else:
  193. converted_dict.update(kwargs)
  194. # Combine our VarData with that of any Vars in the style_dict that was passed.
  195. self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
  196. super().update(converted_dict)
  197. def __setitem__(self, key: str, value: Any):
  198. """Set an item in the style.
  199. Args:
  200. key: The key to set.
  201. value: The value to set.
  202. """
  203. # Create a Var to collapse VarData encoded in f-string.
  204. _var = Var.create(value, _var_is_string=False)
  205. if _var is not None:
  206. # Carry the imports/hooks when setting a Var as a value.
  207. self._var_data = VarData.merge(self._var_data, _var._var_data)
  208. super().__setitem__(key, value)
  209. def _format_emotion_style_pseudo_selector(key: str) -> str:
  210. """Format a pseudo selector for emotion CSS-in-JS.
  211. Args:
  212. key: Underscore-prefixed or colon-prefixed pseudo selector key (_hover).
  213. Returns:
  214. A self-referential pseudo selector key (&:hover).
  215. """
  216. prefix = None
  217. if key.startswith("_"):
  218. # Handle pseudo selectors in chakra style format.
  219. prefix = "&:"
  220. key = key[1:]
  221. if key.startswith(":"):
  222. # Handle pseudo selectors and elements in native format.
  223. prefix = "&"
  224. if prefix is not None:
  225. return prefix + format.to_kebab_case(key)
  226. return key
  227. def format_as_emotion(style_dict: dict[str, Any]) -> Style | None:
  228. """Convert the style to an emotion-compatible CSS-in-JS dict.
  229. Args:
  230. style_dict: The style dict to convert.
  231. Returns:
  232. The emotion style dict.
  233. """
  234. _var_data = style_dict._var_data if isinstance(style_dict, Style) else None
  235. emotion_style = Style()
  236. for orig_key, value in style_dict.items():
  237. key = _format_emotion_style_pseudo_selector(orig_key)
  238. if isinstance(value, (Breakpoints, list)):
  239. if isinstance(value, Breakpoints):
  240. mbps = {
  241. media_query(bp): (
  242. bp_value if isinstance(bp_value, dict) else {key: bp_value}
  243. )
  244. for bp, bp_value in value.items()
  245. }
  246. else:
  247. # Apply media queries from responsive value list.
  248. mbps = {
  249. media_query([0, *breakpoints_values][bp]): (
  250. bp_value if isinstance(bp_value, dict) else {key: bp_value}
  251. )
  252. for bp, bp_value in enumerate(value)
  253. }
  254. if key.startswith("&:"):
  255. emotion_style[key] = mbps
  256. else:
  257. for mq, style_sub_dict in mbps.items():
  258. emotion_style.setdefault(mq, {}).update(style_sub_dict)
  259. elif isinstance(value, dict):
  260. # Recursively format nested style dictionaries.
  261. emotion_style[key] = format_as_emotion(value)
  262. else:
  263. emotion_style[key] = value
  264. if emotion_style:
  265. if _var_data is not None:
  266. emotion_style._var_data = VarData.merge(emotion_style._var_data, _var_data)
  267. return emotion_style
  268. def convert_dict_to_style_and_format_emotion(
  269. raw_dict: dict[str, Any],
  270. ) -> dict[str, Any] | None:
  271. """Convert a dict to a style dict and then format as emotion.
  272. Args:
  273. raw_dict: The dict to convert.
  274. Returns:
  275. The emotion dict.
  276. """
  277. return format_as_emotion(Style(raw_dict))
  278. STACK_CHILDREN_FULL_WIDTH = {
  279. "& :where(.rx-Stack)": {
  280. "width": "100%",
  281. },
  282. "& :where(.rx-Stack) > :where( "
  283. "div:not(.rt-Box, .rx-Upload, .rx-Html),"
  284. "input, select, textarea, table"
  285. ")": {
  286. "width": "100%",
  287. "flex_shrink": "1",
  288. },
  289. }