utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. """Common utility functions used in the compiler."""
  2. from __future__ import annotations
  3. import os
  4. from typing import Any, Callable, Dict, Optional, Type, Union
  5. from urllib.parse import urlparse
  6. from reflex import constants
  7. from reflex.components.base import (
  8. Body,
  9. Description,
  10. DocumentHead,
  11. Head,
  12. Html,
  13. Image,
  14. Main,
  15. Meta,
  16. NextScript,
  17. Title,
  18. )
  19. from reflex.components.component import Component, ComponentStyle, CustomComponent
  20. from reflex.state import BaseState, Cookie, LocalStorage
  21. from reflex.style import Style
  22. from reflex.utils import console, format, imports, path_ops
  23. from reflex.vars import Var
# Re-exported so callers can use `merge_imports` from this module; it is the
# same function object as `reflex.utils.imports.merge_imports`.
merge_imports = imports.merge_imports
  26. def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
  27. """Compile an import statement.
  28. Args:
  29. fields: The set of fields to import from the library.
  30. Returns:
  31. The libraries for default and rest.
  32. default: default library. When install "import def from library".
  33. rest: rest of libraries. When install "import {rest1, rest2} from library"
  34. """
  35. # ignore the ImportVar fields with render=False during compilation
  36. fields_set = {field for field in fields if field.render}
  37. # Check for default imports.
  38. defaults = {field for field in fields_set if field.is_default}
  39. assert len(defaults) < 2
  40. # Get the default import, and the specific imports.
  41. default = next(iter({field.name for field in defaults}), "")
  42. rest = {field.name for field in fields_set - defaults}
  43. return default, list(rest)
  44. def validate_imports(import_dict: imports.ImportDict):
  45. """Verify that the same Tag is not used in multiple import.
  46. Args:
  47. import_dict: The dict of imports to validate
  48. Raises:
  49. ValueError: if a conflict on "tag/alias" is detected for an import.
  50. """
  51. used_tags = {}
  52. for lib, _imports in import_dict.items():
  53. for _import in _imports:
  54. import_name = (
  55. f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag
  56. )
  57. if import_name in used_tags:
  58. raise ValueError(
  59. f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}"
  60. )
  61. if import_name is not None:
  62. used_tags[import_name] = lib
  63. def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
  64. """Compile an import dict.
  65. Args:
  66. import_dict: The import dict to compile.
  67. Returns:
  68. The list of import dict.
  69. """
  70. collapsed_import_dict = imports.collapse_imports(import_dict)
  71. validate_imports(collapsed_import_dict)
  72. import_dicts = []
  73. for lib, fields in collapsed_import_dict.items():
  74. default, rest = compile_import_statement(fields)
  75. # prevent lib from being rendered on the page if all imports are non rendered kind
  76. if not any({f.render for f in fields}): # type: ignore
  77. continue
  78. if not lib:
  79. assert not default, "No default field allowed for empty library."
  80. assert rest is not None and len(rest) > 0, "No fields to import."
  81. for module in sorted(rest):
  82. import_dicts.append(get_import_dict(module))
  83. continue
  84. # remove the version before rendering the package imports
  85. lib = format.format_library_name(lib)
  86. import_dicts.append(get_import_dict(lib, default, rest))
  87. return import_dicts
  88. def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) -> dict:
  89. """Get dictionary for import template.
  90. Args:
  91. lib: The importing react library.
  92. default: The default module to import.
  93. rest: The rest module to import.
  94. Returns:
  95. A dictionary for import template.
  96. """
  97. return {
  98. "lib": lib,
  99. "default": default,
  100. "rest": rest if rest else [],
  101. }
  102. def compile_state(state: Type[BaseState]) -> dict:
  103. """Compile the state of the app.
  104. Args:
  105. state: The app state object.
  106. Returns:
  107. A dictionary of the compiled state.
  108. """
  109. try:
  110. initial_state = state(_reflex_internal_init=True).dict(initial=True)
  111. except Exception as e:
  112. console.warn(
  113. f"Failed to compile initial state with computed vars, excluding them: {e}"
  114. )
  115. initial_state = state(_reflex_internal_init=True).dict(include_computed=False)
  116. return format.format_state(initial_state)
  117. def _compile_client_storage_field(
  118. field,
  119. ) -> tuple[Type[Cookie] | Type[LocalStorage] | None, dict[str, Any] | None]:
  120. """Compile the given cookie or local_storage field.
  121. Args:
  122. field: The possible cookie field to compile.
  123. Returns:
  124. A dictionary of the compiled cookie or None if the field is not cookie-like.
  125. """
  126. for field_type in (Cookie, LocalStorage):
  127. if isinstance(field.default, field_type):
  128. cs_obj = field.default
  129. elif isinstance(field.annotation, type) and issubclass(
  130. field.annotation, field_type
  131. ):
  132. cs_obj = field.annotation()
  133. else:
  134. continue
  135. return field_type, cs_obj.options()
  136. return None, None
  137. def _compile_client_storage_recursive(
  138. state: Type[BaseState],
  139. ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
  140. """Compile the client-side storage for the given state recursively.
  141. Args:
  142. state: The app state object.
  143. Returns:
  144. A tuple of the compiled client-side storage info:
  145. (
  146. cookies: dict[str, dict],
  147. local_storage: dict[str, dict[str, str]]
  148. )
  149. """
  150. cookies = {}
  151. local_storage = {}
  152. state_name = state.get_full_name()
  153. for name, field in state.model_fields.items():
  154. if name in state.inherited_vars:
  155. # only include vars defined in this state
  156. continue
  157. state_key = f"{state_name}.{name}"
  158. field_type, options = _compile_client_storage_field(field)
  159. if field_type is Cookie:
  160. cookies[state_key] = options
  161. elif field_type is LocalStorage:
  162. local_storage[state_key] = options
  163. else:
  164. continue
  165. for substate in state.get_substates():
  166. substate_cookies, substate_local_storage = _compile_client_storage_recursive(
  167. substate
  168. )
  169. cookies.update(substate_cookies)
  170. local_storage.update(substate_local_storage)
  171. return cookies, local_storage
  172. def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
  173. """Compile the client-side storage for the given state.
  174. Args:
  175. state: The app state object.
  176. Returns:
  177. A dictionary of the compiled client-side storage info.
  178. """
  179. cookies, local_storage = _compile_client_storage_recursive(state)
  180. return {
  181. constants.COOKIES: cookies,
  182. constants.LOCAL_STORAGE: local_storage,
  183. }
  184. def compile_custom_component(
  185. component: CustomComponent,
  186. ) -> tuple[dict, imports.ImportDict]:
  187. """Compile a custom component.
  188. Args:
  189. component: The custom component to compile.
  190. Returns:
  191. A tuple of the compiled component and the imports required by the component.
  192. """
  193. # Render the component.
  194. render = component.get_component(component)
  195. # Get the imports.
  196. imports = {
  197. lib: fields
  198. for lib, fields in render.get_imports().items()
  199. if lib != component.library
  200. }
  201. # Concatenate the props.
  202. props = [prop._var_name for prop in component.get_prop_vars()]
  203. # Compile the component.
  204. return (
  205. {
  206. "name": component.tag,
  207. "props": props,
  208. "render": render.render(),
  209. "hooks": render.get_hooks(),
  210. "custom_code": render.get_custom_code(),
  211. },
  212. imports,
  213. )
  214. def create_document_root(
  215. head_components: list[Component] | None = None,
  216. html_lang: Optional[str] = None,
  217. html_custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
  218. ) -> Component:
  219. """Create the document root.
  220. Args:
  221. head_components: The components to add to the head.
  222. html_lang: The language of the document, will be added to the html root element.
  223. html_custom_attrs: custom attributes added to the html root element.
  224. Returns:
  225. The document root.
  226. """
  227. head_components = head_components or []
  228. return Html.create(
  229. DocumentHead.create(*head_components),
  230. Body.create(
  231. Main.create(),
  232. NextScript.create(),
  233. ),
  234. lang=html_lang or "en",
  235. custom_attrs=html_custom_attrs or {},
  236. )
  237. def create_theme(style: ComponentStyle) -> dict:
  238. """Create the base style for the app.
  239. Args:
  240. style: The style dict for the app.
  241. Returns:
  242. The base style for the app.
  243. """
  244. # Get the global style from the style dict.
  245. style_rules = Style({k: v for k, v in style.items() if not isinstance(k, Callable)})
  246. root_style = {
  247. # Root styles.
  248. ":root": Style(
  249. {f"*{k}": v for k, v in style_rules.items() if k.startswith(":")}
  250. ),
  251. # Body styles.
  252. "body": Style(
  253. {k: v for k, v in style_rules.items() if not k.startswith(":")},
  254. ),
  255. }
  256. # Return the theme.
  257. return {"styles": {"global": root_style}}
  258. def get_page_path(path: str) -> str:
  259. """Get the path of the compiled JS file for the given page.
  260. Args:
  261. path: The path of the page.
  262. Returns:
  263. The path of the compiled JS file.
  264. """
  265. return os.path.join(constants.Dirs.WEB_PAGES, path + constants.Ext.JS)
  266. def get_theme_path() -> str:
  267. """Get the path of the base theme style.
  268. Returns:
  269. The path of the theme style.
  270. """
  271. return os.path.join(
  272. constants.Dirs.WEB_UTILS, constants.PageNames.THEME + constants.Ext.JS
  273. )
  274. def get_root_stylesheet_path() -> str:
  275. """Get the path of the app root file.
  276. Returns:
  277. The path of the app root file.
  278. """
  279. return os.path.join(
  280. constants.STYLES_DIR, constants.PageNames.STYLESHEET_ROOT + constants.Ext.CSS
  281. )
  282. def get_context_path() -> str:
  283. """Get the path of the context / initial state file.
  284. Returns:
  285. The path of the context module.
  286. """
  287. return os.path.join(
  288. constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
  289. )
  290. def get_components_path() -> str:
  291. """Get the path of the compiled components.
  292. Returns:
  293. The path of the compiled components.
  294. """
  295. return os.path.join(constants.Dirs.WEB_UTILS, "components" + constants.Ext.JS)
  296. def get_stateful_components_path() -> str:
  297. """Get the path of the compiled stateful components.
  298. Returns:
  299. The path of the compiled stateful components.
  300. """
  301. return os.path.join(
  302. constants.Dirs.WEB_UTILS,
  303. constants.PageNames.STATEFUL_COMPONENTS + constants.Ext.JS,
  304. )
  305. def get_asset_path(filename: str | None = None) -> str:
  306. """Get the path for an asset.
  307. Args:
  308. filename: If given, is added to the root path of assets dir.
  309. Returns:
  310. The path of the asset.
  311. """
  312. console.deprecate(
  313. feature_name="rx.get_asset_path",
  314. reason="use rx.get_upload_dir() instead.",
  315. deprecation_version="0.4.0",
  316. removal_version="0.5.0",
  317. )
  318. if filename is None:
  319. return constants.Dirs.WEB_ASSETS
  320. else:
  321. return os.path.join(constants.Dirs.WEB_ASSETS, filename)
  322. def add_meta(
  323. page: Component, title: str, image: str, description: str, meta: list[dict]
  324. ) -> Component:
  325. """Add metadata to a page.
  326. Args:
  327. page: The component for the page.
  328. title: The title of the page.
  329. image: The image for the page.
  330. description: The description of the page.
  331. meta: The metadata list.
  332. Returns:
  333. The component with the metadata added.
  334. """
  335. meta_tags = [Meta.create(**item) for item in meta]
  336. page.children.append(
  337. Head.create(
  338. Title.create(title),
  339. Description.create(content=description),
  340. Image.create(content=image),
  341. *meta_tags,
  342. )
  343. )
  344. return page
  345. def write_page(path: str, code: str):
  346. """Write the given code to the given path.
  347. Args:
  348. path: The path to write the code to.
  349. code: The code to write.
  350. """
  351. path_ops.mkdir(os.path.dirname(path))
  352. with open(path, "w", encoding="utf-8") as f:
  353. f.write(code)
  354. def empty_dir(path: str, keep_files: list[str] | None = None):
  355. """Remove all files and folders in a directory except for the keep_files.
  356. Args:
  357. path: The path to the directory that will be emptied
  358. keep_files: List of filenames or foldernames that will not be deleted.
  359. """
  360. # If the directory does not exist, return.
  361. if not os.path.exists(path):
  362. return
  363. # Remove all files and folders in the directory.
  364. keep_files = keep_files or []
  365. directory_contents = os.listdir(path)
  366. for element in directory_contents:
  367. if element not in keep_files:
  368. path_ops.rm(os.path.join(path, element))
  369. def is_valid_url(url) -> bool:
  370. """Check if a url is valid.
  371. Args:
  372. url: The Url to check.
  373. Returns:
  374. Whether url is valid.
  375. """
  376. result = urlparse(url)
  377. return all([result.scheme, result.netloc])