utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. """Common utility functions used in the compiler."""
  2. from __future__ import annotations
  3. import os
  4. from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
  5. from urllib.parse import urlparse
  6. try:
  7. # TODO The type checking guard can be removed once
  8. # reflex-hosting-cli tools are compatible with pydantic v2
  9. if not TYPE_CHECKING:
  10. from pydantic.v1.fields import ModelField
  11. else:
  12. raise ModuleNotFoundError
  13. except ModuleNotFoundError:
  14. from pydantic.fields import ModelField
  15. from reflex import constants
  16. from reflex.components.base import (
  17. Body,
  18. Description,
  19. DocumentHead,
  20. Head,
  21. Html,
  22. Image,
  23. Main,
  24. Meta,
  25. NextScript,
  26. Title,
  27. )
  28. from reflex.components.component import Component, ComponentStyle, CustomComponent
  29. from reflex.state import BaseState, Cookie, LocalStorage
  30. from reflex.style import Style
  31. from reflex.utils import console, format, imports, path_ops
  32. from reflex.vars import Var
  33. # To re-export this function.
  34. merge_imports = imports.merge_imports
  35. def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
  36. """Compile an import statement.
  37. Args:
  38. fields: The set of fields to import from the library.
  39. Returns:
  40. The libraries for default and rest.
  41. default: default library. When install "import def from library".
  42. rest: rest of libraries. When install "import {rest1, rest2} from library"
  43. """
  44. # ignore the ImportVar fields with render=False during compilation
  45. fields_set = {field for field in fields if field.render}
  46. # Check for default imports.
  47. defaults = {field for field in fields_set if field.is_default}
  48. assert len(defaults) < 2
  49. # Get the default import, and the specific imports.
  50. default = next(iter({field.name for field in defaults}), "")
  51. rest = {field.name for field in fields_set - defaults}
  52. return default, list(rest)
  53. def validate_imports(import_dict: imports.ImportDict):
  54. """Verify that the same Tag is not used in multiple import.
  55. Args:
  56. import_dict: The dict of imports to validate
  57. Raises:
  58. ValueError: if a conflict on "tag/alias" is detected for an import.
  59. """
  60. used_tags = {}
  61. for lib, _imports in import_dict.items():
  62. for _import in _imports:
  63. import_name = (
  64. f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag
  65. )
  66. if import_name in used_tags:
  67. raise ValueError(
  68. f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}"
  69. )
  70. if import_name is not None:
  71. used_tags[import_name] = lib
  72. def compile_imports(import_list: imports.ImportList) -> list[dict]:
  73. """Compile an import list.
  74. Args:
  75. import_list: The import list to compile.
  76. Returns:
  77. The list of template import dict.
  78. """
  79. collapsed_import_dict = import_list.collapse()
  80. validate_imports(collapsed_import_dict)
  81. import_dicts = []
  82. for lib, fields in collapsed_import_dict.items():
  83. default, rest = compile_import_statement(fields)
  84. # prevent lib from being rendered on the page if all imports are non rendered kind
  85. if not any({f.render for f in fields}): # type: ignore
  86. continue
  87. if not lib:
  88. assert not default, "No default field allowed for empty library."
  89. assert rest is not None and len(rest) > 0, "No fields to import."
  90. for module in sorted(rest):
  91. import_dicts.append(get_import_dict(module))
  92. continue
  93. import_dicts.append(get_import_dict(lib, default, rest))
  94. return import_dicts
  95. def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) -> dict:
  96. """Get dictionary for import template.
  97. Args:
  98. lib: The importing react library.
  99. default: The default module to import.
  100. rest: The rest module to import.
  101. Returns:
  102. A dictionary for import template.
  103. """
  104. return {
  105. "lib": lib,
  106. "default": default,
  107. "rest": rest if rest else [],
  108. }
  109. def compile_state(state: Type[BaseState]) -> dict:
  110. """Compile the state of the app.
  111. Args:
  112. state: The app state object.
  113. Returns:
  114. A dictionary of the compiled state.
  115. """
  116. try:
  117. initial_state = state(_reflex_internal_init=True).dict(initial=True)
  118. except Exception as e:
  119. console.warn(
  120. f"Failed to compile initial state with computed vars, excluding them: {e}"
  121. )
  122. initial_state = state(_reflex_internal_init=True).dict(include_computed=False)
  123. return format.format_state(initial_state)
  124. def _compile_client_storage_field(
  125. field: ModelField,
  126. ) -> tuple[Type[Cookie] | Type[LocalStorage] | None, dict[str, Any] | None]:
  127. """Compile the given cookie or local_storage field.
  128. Args:
  129. field: The possible cookie field to compile.
  130. Returns:
  131. A dictionary of the compiled cookie or None if the field is not cookie-like.
  132. """
  133. for field_type in (Cookie, LocalStorage):
  134. if isinstance(field.default, field_type):
  135. cs_obj = field.default
  136. elif isinstance(field.type_, type) and issubclass(field.type_, field_type):
  137. cs_obj = field.type_()
  138. else:
  139. continue
  140. return field_type, cs_obj.options()
  141. return None, None
  142. def _compile_client_storage_recursive(
  143. state: Type[BaseState],
  144. ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
  145. """Compile the client-side storage for the given state recursively.
  146. Args:
  147. state: The app state object.
  148. Returns:
  149. A tuple of the compiled client-side storage info:
  150. (
  151. cookies: dict[str, dict],
  152. local_storage: dict[str, dict[str, str]]
  153. )
  154. """
  155. cookies = {}
  156. local_storage = {}
  157. state_name = state.get_full_name()
  158. for name, field in state.__fields__.items():
  159. if name in state.inherited_vars:
  160. # only include vars defined in this state
  161. continue
  162. state_key = f"{state_name}.{name}"
  163. field_type, options = _compile_client_storage_field(field)
  164. if field_type is Cookie:
  165. cookies[state_key] = options
  166. elif field_type is LocalStorage:
  167. local_storage[state_key] = options
  168. else:
  169. continue
  170. for substate in state.get_substates():
  171. substate_cookies, substate_local_storage = _compile_client_storage_recursive(
  172. substate
  173. )
  174. cookies.update(substate_cookies)
  175. local_storage.update(substate_local_storage)
  176. return cookies, local_storage
  177. def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
  178. """Compile the client-side storage for the given state.
  179. Args:
  180. state: The app state object.
  181. Returns:
  182. A dictionary of the compiled client-side storage info.
  183. """
  184. cookies, local_storage = _compile_client_storage_recursive(state)
  185. return {
  186. constants.COOKIES: cookies,
  187. constants.LOCAL_STORAGE: local_storage,
  188. }
  189. def compile_custom_component(
  190. component: CustomComponent,
  191. ) -> tuple[dict, imports.ImportList]:
  192. """Compile a custom component.
  193. Args:
  194. component: The custom component to compile.
  195. Returns:
  196. A tuple of the compiled component and the imports required by the component.
  197. """
  198. # Render the component.
  199. render = component.get_component(component)
  200. # Get the imports.
  201. component_library_name = format.format_library_name(component.library)
  202. _imports = imports.ImportList(
  203. imp
  204. for imp in render._get_all_imports()
  205. if imp.library != component_library_name
  206. )
  207. # Concatenate the props.
  208. props = [prop._var_name for prop in component.get_prop_vars()]
  209. # Compile the component.
  210. return (
  211. {
  212. "name": component.tag,
  213. "props": props,
  214. "render": render.render(),
  215. "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
  216. "custom_code": render._get_all_custom_code(),
  217. },
  218. _imports,
  219. )
  220. def create_document_root(
  221. head_components: list[Component] | None = None,
  222. html_lang: Optional[str] = None,
  223. html_custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
  224. ) -> Component:
  225. """Create the document root.
  226. Args:
  227. head_components: The components to add to the head.
  228. html_lang: The language of the document, will be added to the html root element.
  229. html_custom_attrs: custom attributes added to the html root element.
  230. Returns:
  231. The document root.
  232. """
  233. head_components = head_components or []
  234. return Html.create(
  235. DocumentHead.create(*head_components),
  236. Body.create(
  237. Main.create(),
  238. NextScript.create(),
  239. ),
  240. lang=html_lang or "en",
  241. custom_attrs=html_custom_attrs or {},
  242. )
  243. def create_theme(style: ComponentStyle) -> dict:
  244. """Create the base style for the app.
  245. Args:
  246. style: The style dict for the app.
  247. Returns:
  248. The base style for the app.
  249. """
  250. # Get the global style from the style dict.
  251. style_rules = Style({k: v for k, v in style.items() if not isinstance(k, Callable)})
  252. root_style = {
  253. # Root styles.
  254. ":root": Style(
  255. {f"*{k}": v for k, v in style_rules.items() if k.startswith(":")}
  256. ),
  257. # Body styles.
  258. "body": Style(
  259. {k: v for k, v in style_rules.items() if not k.startswith(":")},
  260. ),
  261. }
  262. # Return the theme.
  263. return {"styles": {"global": root_style}}
  264. def get_page_path(path: str) -> str:
  265. """Get the path of the compiled JS file for the given page.
  266. Args:
  267. path: The path of the page.
  268. Returns:
  269. The path of the compiled JS file.
  270. """
  271. return os.path.join(constants.Dirs.WEB_PAGES, path + constants.Ext.JS)
  272. def get_theme_path() -> str:
  273. """Get the path of the base theme style.
  274. Returns:
  275. The path of the theme style.
  276. """
  277. return os.path.join(
  278. constants.Dirs.WEB_UTILS, constants.PageNames.THEME + constants.Ext.JS
  279. )
  280. def get_root_stylesheet_path() -> str:
  281. """Get the path of the app root file.
  282. Returns:
  283. The path of the app root file.
  284. """
  285. return os.path.join(
  286. constants.STYLES_DIR, constants.PageNames.STYLESHEET_ROOT + constants.Ext.CSS
  287. )
  288. def get_context_path() -> str:
  289. """Get the path of the context / initial state file.
  290. Returns:
  291. The path of the context module.
  292. """
  293. return os.path.join(
  294. constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
  295. )
  296. def get_components_path() -> str:
  297. """Get the path of the compiled components.
  298. Returns:
  299. The path of the compiled components.
  300. """
  301. return os.path.join(constants.Dirs.WEB_UTILS, "components" + constants.Ext.JS)
  302. def get_stateful_components_path() -> str:
  303. """Get the path of the compiled stateful components.
  304. Returns:
  305. The path of the compiled stateful components.
  306. """
  307. return os.path.join(
  308. constants.Dirs.WEB_UTILS,
  309. constants.PageNames.STATEFUL_COMPONENTS + constants.Ext.JS,
  310. )
  311. def get_asset_path(filename: str | None = None) -> str:
  312. """Get the path for an asset.
  313. Args:
  314. filename: If given, is added to the root path of assets dir.
  315. Returns:
  316. The path of the asset.
  317. """
  318. console.deprecate(
  319. feature_name="rx.get_asset_path",
  320. reason="use rx.get_upload_dir() instead.",
  321. deprecation_version="0.4.0",
  322. removal_version="0.5.0",
  323. )
  324. if filename is None:
  325. return constants.Dirs.WEB_ASSETS
  326. else:
  327. return os.path.join(constants.Dirs.WEB_ASSETS, filename)
  328. def add_meta(
  329. page: Component,
  330. title: str,
  331. image: str,
  332. meta: list[dict],
  333. description: str | None = None,
  334. ) -> Component:
  335. """Add metadata to a page.
  336. Args:
  337. page: The component for the page.
  338. title: The title of the page.
  339. image: The image for the page.
  340. meta: The metadata list.
  341. description: The description of the page.
  342. Returns:
  343. The component with the metadata added.
  344. """
  345. meta_tags = [Meta.create(**item) for item in meta]
  346. children: list[Any] = [
  347. Title.create(title),
  348. ]
  349. if description:
  350. children.append(Description.create(content=description))
  351. children.append(Image.create(content=image))
  352. page.children.append(
  353. Head.create(
  354. *children,
  355. *meta_tags,
  356. )
  357. )
  358. return page
  359. def write_page(path: str, code: str):
  360. """Write the given code to the given path.
  361. Args:
  362. path: The path to write the code to.
  363. code: The code to write.
  364. """
  365. path_ops.mkdir(os.path.dirname(path))
  366. with open(path, "w", encoding="utf-8") as f:
  367. f.write(code)
  368. def empty_dir(path: str, keep_files: list[str] | None = None):
  369. """Remove all files and folders in a directory except for the keep_files.
  370. Args:
  371. path: The path to the directory that will be emptied
  372. keep_files: List of filenames or foldernames that will not be deleted.
  373. """
  374. # If the directory does not exist, return.
  375. if not os.path.exists(path):
  376. return
  377. # Remove all files and folders in the directory.
  378. keep_files = keep_files or []
  379. directory_contents = os.listdir(path)
  380. for element in directory_contents:
  381. if element not in keep_files:
  382. path_ops.rm(os.path.join(path, element))
  383. def is_valid_url(url) -> bool:
  384. """Check if a url is valid.
  385. Args:
  386. url: The Url to check.
  387. Returns:
  388. Whether url is valid.
  389. """
  390. result = urlparse(url)
  391. return all([result.scheme, result.netloc])