utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. """Common utility functions used in the compiler."""
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from typing import Any, Callable, Dict, Optional, Type, Union
  5. from urllib.parse import urlparse
  6. from reflex.utils.prerequisites import get_web_dir
  7. from reflex.vars.base import Var
  8. try:
  9. from pydantic.v1.fields import ModelField
  10. except ModuleNotFoundError:
  11. from pydantic.fields import ModelField # type: ignore
  12. from reflex import constants
  13. from reflex.components.base import (
  14. Body,
  15. Description,
  16. DocumentHead,
  17. Head,
  18. Html,
  19. Image,
  20. Main,
  21. Meta,
  22. NextScript,
  23. Title,
  24. )
  25. from reflex.components.component import Component, ComponentStyle, CustomComponent
  26. from reflex.state import BaseState, Cookie, LocalStorage, SessionStorage
  27. from reflex.style import Style
  28. from reflex.utils import console, format, imports, path_ops
  29. from reflex.utils.imports import ImportVar, ParsedImportDict
# Re-export ``merge_imports`` from ``reflex.utils.imports`` so it is also
# reachable through this module's namespace.
merge_imports = imports.merge_imports
  32. def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
  33. """Compile an import statement.
  34. Args:
  35. fields: The set of fields to import from the library.
  36. Raises:
  37. ValueError: If there is more than one default import.
  38. Returns:
  39. The libraries for default and rest.
  40. default: default library. When install "import def from library".
  41. rest: rest of libraries. When install "import {rest1, rest2} from library"
  42. """
  43. # ignore the ImportVar fields with render=False during compilation
  44. fields_set = {field for field in fields if field.render}
  45. # Check for default imports.
  46. defaults = {field for field in fields_set if field.is_default}
  47. if len(defaults) >= 2:
  48. raise ValueError("Only one default import is allowed.")
  49. # Get the default import, and the specific imports.
  50. default = next(iter({field.name for field in defaults}), "")
  51. rest = {field.name for field in fields_set - defaults}
  52. return default, list(rest)
  53. def validate_imports(import_dict: ParsedImportDict):
  54. """Verify that the same Tag is not used in multiple import.
  55. Args:
  56. import_dict: The dict of imports to validate
  57. Raises:
  58. ValueError: if a conflict on "tag/alias" is detected for an import.
  59. """
  60. used_tags = {}
  61. for lib, _imports in import_dict.items():
  62. for _import in _imports:
  63. import_name = (
  64. f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag
  65. )
  66. if import_name in used_tags:
  67. raise ValueError(
  68. f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}"
  69. )
  70. if import_name is not None:
  71. used_tags[import_name] = lib
  72. def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
  73. """Compile an import dict.
  74. Args:
  75. import_dict: The import dict to compile.
  76. Raises:
  77. ValueError: If an import in the dict is invalid.
  78. Returns:
  79. The list of import dict.
  80. """
  81. collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict)
  82. validate_imports(collapsed_import_dict)
  83. import_dicts = []
  84. for lib, fields in collapsed_import_dict.items():
  85. default, rest = compile_import_statement(fields)
  86. # prevent lib from being rendered on the page if all imports are non rendered kind
  87. if not any({f.render for f in fields}): # type: ignore
  88. continue
  89. if not lib:
  90. if default:
  91. raise ValueError("No default field allowed for empty library.")
  92. if rest is None or len(rest) == 0:
  93. raise ValueError("No fields to import.")
  94. for module in sorted(rest):
  95. import_dicts.append(get_import_dict(module))
  96. continue
  97. # remove the version before rendering the package imports
  98. lib = format.format_library_name(lib)
  99. import_dicts.append(get_import_dict(lib, default, rest))
  100. return import_dicts
  101. def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) -> dict:
  102. """Get dictionary for import template.
  103. Args:
  104. lib: The importing react library.
  105. default: The default module to import.
  106. rest: The rest module to import.
  107. Returns:
  108. A dictionary for import template.
  109. """
  110. return {
  111. "lib": lib,
  112. "default": default,
  113. "rest": rest if rest else [],
  114. }
  115. def compile_state(state: Type[BaseState]) -> dict:
  116. """Compile the state of the app.
  117. Args:
  118. state: The app state object.
  119. Returns:
  120. A dictionary of the compiled state.
  121. """
  122. try:
  123. initial_state = state(_reflex_internal_init=True).dict(initial=True)
  124. except Exception as e:
  125. console.warn(
  126. f"Failed to compile initial state with computed vars, excluding them: {e}"
  127. )
  128. initial_state = state(_reflex_internal_init=True).dict(
  129. initial=True, include_computed=False
  130. )
  131. return initial_state
  132. def _compile_client_storage_field(
  133. field: ModelField,
  134. ) -> tuple[
  135. Type[Cookie] | Type[LocalStorage] | Type[SessionStorage] | None,
  136. dict[str, Any] | None,
  137. ]:
  138. """Compile the given cookie, local_storage or session_storage field.
  139. Args:
  140. field: The possible cookie field to compile.
  141. Returns:
  142. A dictionary of the compiled cookie or None if the field is not cookie-like.
  143. """
  144. for field_type in (Cookie, LocalStorage, SessionStorage):
  145. if isinstance(field.default, field_type):
  146. cs_obj = field.default
  147. elif isinstance(field.type_, type) and issubclass(field.type_, field_type):
  148. cs_obj = field.type_()
  149. else:
  150. continue
  151. return field_type, cs_obj.options()
  152. return None, None
  153. def _compile_client_storage_recursive(
  154. state: Type[BaseState],
  155. ) -> tuple[dict[str, dict], dict[str, dict], dict[str, dict]]:
  156. """Compile the client-side storage for the given state recursively.
  157. Args:
  158. state: The app state object.
  159. Returns:
  160. A tuple of the compiled client-side storage info:
  161. (
  162. cookies: dict[str, dict],
  163. local_storage: dict[str, dict[str, str]]
  164. session_storage: dict[str, dict[str, str]]
  165. ).
  166. """
  167. cookies = {}
  168. local_storage = {}
  169. session_storage = {}
  170. state_name = state.get_full_name()
  171. for name, field in state.__fields__.items():
  172. if name in state.inherited_vars:
  173. # only include vars defined in this state
  174. continue
  175. state_key = f"{state_name}.{name}"
  176. field_type, options = _compile_client_storage_field(field)
  177. if field_type is Cookie:
  178. cookies[state_key] = options
  179. elif field_type is LocalStorage:
  180. local_storage[state_key] = options
  181. elif field_type is SessionStorage:
  182. session_storage[state_key] = options
  183. else:
  184. continue
  185. for substate in state.get_substates():
  186. (
  187. substate_cookies,
  188. substate_local_storage,
  189. substate_session_storage,
  190. ) = _compile_client_storage_recursive(substate)
  191. cookies.update(substate_cookies)
  192. local_storage.update(substate_local_storage)
  193. session_storage.update(substate_session_storage)
  194. return cookies, local_storage, session_storage
  195. def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
  196. """Compile the client-side storage for the given state.
  197. Args:
  198. state: The app state object.
  199. Returns:
  200. A dictionary of the compiled client-side storage info.
  201. """
  202. cookies, local_storage, session_storage = _compile_client_storage_recursive(state)
  203. return {
  204. constants.COOKIES: cookies,
  205. constants.LOCAL_STORAGE: local_storage,
  206. constants.SESSION_STORAGE: session_storage,
  207. }
  208. def compile_custom_component(
  209. component: CustomComponent,
  210. ) -> tuple[dict, ParsedImportDict]:
  211. """Compile a custom component.
  212. Args:
  213. component: The custom component to compile.
  214. Returns:
  215. A tuple of the compiled component and the imports required by the component.
  216. """
  217. # Render the component.
  218. render = component.get_component(component)
  219. # Get the imports.
  220. imports: ParsedImportDict = {
  221. lib: fields
  222. for lib, fields in render._get_all_imports().items()
  223. if lib != component.library
  224. }
  225. # Concatenate the props.
  226. props = [prop._js_expr for prop in component.get_prop_vars()]
  227. # Compile the component.
  228. return (
  229. {
  230. "name": component.tag,
  231. "props": props,
  232. "render": render.render(),
  233. "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
  234. "custom_code": render._get_all_custom_code(),
  235. },
  236. imports,
  237. )
  238. def create_document_root(
  239. head_components: list[Component] | None = None,
  240. html_lang: Optional[str] = None,
  241. html_custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
  242. ) -> Component:
  243. """Create the document root.
  244. Args:
  245. head_components: The components to add to the head.
  246. html_lang: The language of the document, will be added to the html root element.
  247. html_custom_attrs: custom attributes added to the html root element.
  248. Returns:
  249. The document root.
  250. """
  251. head_components = head_components or []
  252. return Html.create(
  253. DocumentHead.create(*head_components),
  254. Body.create(
  255. Main.create(),
  256. NextScript.create(),
  257. ),
  258. lang=html_lang or "en",
  259. custom_attrs=html_custom_attrs or {},
  260. )
  261. def create_theme(style: ComponentStyle) -> dict:
  262. """Create the base style for the app.
  263. Args:
  264. style: The style dict for the app.
  265. Returns:
  266. The base style for the app.
  267. """
  268. # Get the global style from the style dict.
  269. style_rules = Style({k: v for k, v in style.items() if not isinstance(k, Callable)})
  270. root_style = {
  271. # Root styles.
  272. ":root": Style(
  273. {f"*{k}": v for k, v in style_rules.items() if k.startswith(":")}
  274. ),
  275. # Body styles.
  276. "body": Style(
  277. {k: v for k, v in style_rules.items() if not k.startswith(":")},
  278. ),
  279. }
  280. # Return the theme.
  281. return {"styles": {"global": root_style}}
  282. def get_page_path(path: str) -> str:
  283. """Get the path of the compiled JS file for the given page.
  284. Args:
  285. path: The path of the page.
  286. Returns:
  287. The path of the compiled JS file.
  288. """
  289. return str(get_web_dir() / constants.Dirs.PAGES / (path + constants.Ext.JS))
  290. def get_theme_path() -> str:
  291. """Get the path of the base theme style.
  292. Returns:
  293. The path of the theme style.
  294. """
  295. return str(
  296. get_web_dir()
  297. / constants.Dirs.UTILS
  298. / (constants.PageNames.THEME + constants.Ext.JS)
  299. )
  300. def get_root_stylesheet_path() -> str:
  301. """Get the path of the app root file.
  302. Returns:
  303. The path of the app root file.
  304. """
  305. return str(
  306. get_web_dir()
  307. / constants.Dirs.STYLES
  308. / (constants.PageNames.STYLESHEET_ROOT + constants.Ext.CSS)
  309. )
  310. def get_context_path() -> str:
  311. """Get the path of the context / initial state file.
  312. Returns:
  313. The path of the context module.
  314. """
  315. return str(get_web_dir() / (constants.Dirs.CONTEXTS_PATH + constants.Ext.JS))
  316. def get_components_path() -> str:
  317. """Get the path of the compiled components.
  318. Returns:
  319. The path of the compiled components.
  320. """
  321. return str(
  322. get_web_dir()
  323. / constants.Dirs.UTILS
  324. / (constants.PageNames.COMPONENTS + constants.Ext.JS),
  325. )
  326. def get_stateful_components_path() -> str:
  327. """Get the path of the compiled stateful components.
  328. Returns:
  329. The path of the compiled stateful components.
  330. """
  331. return str(
  332. get_web_dir()
  333. / constants.Dirs.UTILS
  334. / (constants.PageNames.STATEFUL_COMPONENTS + constants.Ext.JS)
  335. )
  336. def add_meta(
  337. page: Component,
  338. title: str,
  339. image: str,
  340. meta: list[dict],
  341. description: str | None = None,
  342. ) -> Component:
  343. """Add metadata to a page.
  344. Args:
  345. page: The component for the page.
  346. title: The title of the page.
  347. image: The image for the page.
  348. meta: The metadata list.
  349. description: The description of the page.
  350. Returns:
  351. The component with the metadata added.
  352. """
  353. meta_tags = [
  354. item if isinstance(item, Component) else Meta.create(**item) for item in meta
  355. ]
  356. children: list[Any] = [Title.create(title)]
  357. if description:
  358. children.append(Description.create(content=description))
  359. children.append(Image.create(content=image))
  360. page.children.append(
  361. Head.create(
  362. *children,
  363. *meta_tags,
  364. )
  365. )
  366. return page
  367. def write_page(path: str | Path, code: str):
  368. """Write the given code to the given path.
  369. Args:
  370. path: The path to write the code to.
  371. code: The code to write.
  372. """
  373. path = Path(path)
  374. path_ops.mkdir(path.parent)
  375. path.write_text(code, encoding="utf-8")
  376. def empty_dir(path: str | Path, keep_files: list[str] | None = None):
  377. """Remove all files and folders in a directory except for the keep_files.
  378. Args:
  379. path: The path to the directory that will be emptied
  380. keep_files: List of filenames or foldernames that will not be deleted.
  381. """
  382. path = Path(path)
  383. # If the directory does not exist, return.
  384. if not path.exists():
  385. return
  386. # Remove all files and folders in the directory.
  387. keep_files = keep_files or []
  388. for element in path.iterdir():
  389. if element.name not in keep_files:
  390. path_ops.rm(element)
  391. def is_valid_url(url) -> bool:
  392. """Check if a url is valid.
  393. Args:
  394. url: The Url to check.
  395. Returns:
  396. Whether url is valid.
  397. """
  398. result = urlparse(url)
  399. return all([result.scheme, result.netloc])