utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. """Common utility functions used in the compiler."""
  2. from __future__ import annotations
  3. import os
  4. from typing import Any, Callable, Dict, Optional, Type, Union
  5. from urllib.parse import urlparse
  6. try:
  7. from pydantic.v1.fields import ModelField
  8. except ModuleNotFoundError:
  9. from pydantic.fields import ModelField # type: ignore
  10. from reflex import constants
  11. from reflex.components.base import (
  12. Body,
  13. Description,
  14. DocumentHead,
  15. Head,
  16. Html,
  17. Image,
  18. Main,
  19. Meta,
  20. NextScript,
  21. Title,
  22. )
  23. from reflex.components.component import Component, ComponentStyle, CustomComponent
  24. from reflex.state import BaseState, Cookie, LocalStorage
  25. from reflex.style import Style
  26. from reflex.utils import console, format, imports, path_ops
  27. from reflex.utils.imports import ImportVar, ParsedImportDict
  28. from reflex.vars import Var
# Re-export `imports.merge_imports` so callers can import it from this module.
merge_imports = imports.merge_imports
  31. def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
  32. """Compile an import statement.
  33. Args:
  34. fields: The set of fields to import from the library.
  35. Returns:
  36. The libraries for default and rest.
  37. default: default library. When install "import def from library".
  38. rest: rest of libraries. When install "import {rest1, rest2} from library"
  39. """
  40. # ignore the ImportVar fields with render=False during compilation
  41. fields_set = {field for field in fields if field.render}
  42. # Check for default imports.
  43. defaults = {field for field in fields_set if field.is_default}
  44. assert len(defaults) < 2
  45. # Get the default import, and the specific imports.
  46. default = next(iter({field.name for field in defaults}), "")
  47. rest = {field.name for field in fields_set - defaults}
  48. return default, list(rest)
  49. def validate_imports(import_dict: ParsedImportDict):
  50. """Verify that the same Tag is not used in multiple import.
  51. Args:
  52. import_dict: The dict of imports to validate
  53. Raises:
  54. ValueError: if a conflict on "tag/alias" is detected for an import.
  55. """
  56. used_tags = {}
  57. for lib, _imports in import_dict.items():
  58. for _import in _imports:
  59. import_name = (
  60. f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag
  61. )
  62. if import_name in used_tags:
  63. raise ValueError(
  64. f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}"
  65. )
  66. if import_name is not None:
  67. used_tags[import_name] = lib
  68. def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
  69. """Compile an import dict.
  70. Args:
  71. import_dict: The import dict to compile.
  72. Returns:
  73. The list of import dict.
  74. """
  75. collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict)
  76. validate_imports(collapsed_import_dict)
  77. import_dicts = []
  78. for lib, fields in collapsed_import_dict.items():
  79. default, rest = compile_import_statement(fields)
  80. # prevent lib from being rendered on the page if all imports are non rendered kind
  81. if not any({f.render for f in fields}): # type: ignore
  82. continue
  83. if not lib:
  84. assert not default, "No default field allowed for empty library."
  85. assert rest is not None and len(rest) > 0, "No fields to import."
  86. for module in sorted(rest):
  87. import_dicts.append(get_import_dict(module))
  88. continue
  89. # remove the version before rendering the package imports
  90. lib = format.format_library_name(lib)
  91. import_dicts.append(get_import_dict(lib, default, rest))
  92. return import_dicts
  93. def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) -> dict:
  94. """Get dictionary for import template.
  95. Args:
  96. lib: The importing react library.
  97. default: The default module to import.
  98. rest: The rest module to import.
  99. Returns:
  100. A dictionary for import template.
  101. """
  102. return {
  103. "lib": lib,
  104. "default": default,
  105. "rest": rest if rest else [],
  106. }
  107. def compile_state(state: Type[BaseState]) -> dict:
  108. """Compile the state of the app.
  109. Args:
  110. state: The app state object.
  111. Returns:
  112. A dictionary of the compiled state.
  113. """
  114. try:
  115. initial_state = state(_reflex_internal_init=True).dict(initial=True)
  116. except Exception as e:
  117. console.warn(
  118. f"Failed to compile initial state with computed vars, excluding them: {e}"
  119. )
  120. initial_state = state(_reflex_internal_init=True).dict(include_computed=False)
  121. return format.format_state(initial_state)
  122. def _compile_client_storage_field(
  123. field: ModelField,
  124. ) -> tuple[Type[Cookie] | Type[LocalStorage] | None, dict[str, Any] | None]:
  125. """Compile the given cookie or local_storage field.
  126. Args:
  127. field: The possible cookie field to compile.
  128. Returns:
  129. A dictionary of the compiled cookie or None if the field is not cookie-like.
  130. """
  131. for field_type in (Cookie, LocalStorage):
  132. if isinstance(field.default, field_type):
  133. cs_obj = field.default
  134. elif isinstance(field.type_, type) and issubclass(field.type_, field_type):
  135. cs_obj = field.type_()
  136. else:
  137. continue
  138. return field_type, cs_obj.options()
  139. return None, None
  140. def _compile_client_storage_recursive(
  141. state: Type[BaseState],
  142. ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
  143. """Compile the client-side storage for the given state recursively.
  144. Args:
  145. state: The app state object.
  146. Returns:
  147. A tuple of the compiled client-side storage info:
  148. (
  149. cookies: dict[str, dict],
  150. local_storage: dict[str, dict[str, str]]
  151. )
  152. """
  153. cookies = {}
  154. local_storage = {}
  155. state_name = state.get_full_name()
  156. for name, field in state.__fields__.items():
  157. if name in state.inherited_vars:
  158. # only include vars defined in this state
  159. continue
  160. state_key = f"{state_name}.{name}"
  161. field_type, options = _compile_client_storage_field(field)
  162. if field_type is Cookie:
  163. cookies[state_key] = options
  164. elif field_type is LocalStorage:
  165. local_storage[state_key] = options
  166. else:
  167. continue
  168. for substate in state.get_substates():
  169. substate_cookies, substate_local_storage = _compile_client_storage_recursive(
  170. substate
  171. )
  172. cookies.update(substate_cookies)
  173. local_storage.update(substate_local_storage)
  174. return cookies, local_storage
  175. def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
  176. """Compile the client-side storage for the given state.
  177. Args:
  178. state: The app state object.
  179. Returns:
  180. A dictionary of the compiled client-side storage info.
  181. """
  182. cookies, local_storage = _compile_client_storage_recursive(state)
  183. return {
  184. constants.COOKIES: cookies,
  185. constants.LOCAL_STORAGE: local_storage,
  186. }
  187. def compile_custom_component(
  188. component: CustomComponent,
  189. ) -> tuple[dict, ParsedImportDict]:
  190. """Compile a custom component.
  191. Args:
  192. component: The custom component to compile.
  193. Returns:
  194. A tuple of the compiled component and the imports required by the component.
  195. """
  196. # Render the component.
  197. render = component.get_component(component)
  198. # Get the imports.
  199. imports: ParsedImportDict = {
  200. lib: fields
  201. for lib, fields in render._get_all_imports().items()
  202. if lib != component.library
  203. }
  204. # Concatenate the props.
  205. props = [prop._var_name for prop in component.get_prop_vars()]
  206. # Compile the component.
  207. return (
  208. {
  209. "name": component.tag,
  210. "props": props,
  211. "render": render.render(),
  212. "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
  213. "custom_code": render._get_all_custom_code(),
  214. },
  215. imports,
  216. )
  217. def create_document_root(
  218. head_components: list[Component] | None = None,
  219. html_lang: Optional[str] = None,
  220. html_custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
  221. ) -> Component:
  222. """Create the document root.
  223. Args:
  224. head_components: The components to add to the head.
  225. html_lang: The language of the document, will be added to the html root element.
  226. html_custom_attrs: custom attributes added to the html root element.
  227. Returns:
  228. The document root.
  229. """
  230. head_components = head_components or []
  231. return Html.create(
  232. DocumentHead.create(*head_components),
  233. Body.create(
  234. Main.create(),
  235. NextScript.create(),
  236. ),
  237. lang=html_lang or "en",
  238. custom_attrs=html_custom_attrs or {},
  239. )
  240. def create_theme(style: ComponentStyle) -> dict:
  241. """Create the base style for the app.
  242. Args:
  243. style: The style dict for the app.
  244. Returns:
  245. The base style for the app.
  246. """
  247. # Get the global style from the style dict.
  248. style_rules = Style({k: v for k, v in style.items() if not isinstance(k, Callable)})
  249. root_style = {
  250. # Root styles.
  251. ":root": Style(
  252. {f"*{k}": v for k, v in style_rules.items() if k.startswith(":")}
  253. ),
  254. # Body styles.
  255. "body": Style(
  256. {k: v for k, v in style_rules.items() if not k.startswith(":")},
  257. ),
  258. }
  259. # Return the theme.
  260. return {"styles": {"global": root_style}}
  261. def get_page_path(path: str) -> str:
  262. """Get the path of the compiled JS file for the given page.
  263. Args:
  264. path: The path of the page.
  265. Returns:
  266. The path of the compiled JS file.
  267. """
  268. return os.path.join(constants.Dirs.WEB_PAGES, path + constants.Ext.JS)
  269. def get_theme_path() -> str:
  270. """Get the path of the base theme style.
  271. Returns:
  272. The path of the theme style.
  273. """
  274. return os.path.join(
  275. constants.Dirs.WEB_UTILS, constants.PageNames.THEME + constants.Ext.JS
  276. )
  277. def get_root_stylesheet_path() -> str:
  278. """Get the path of the app root file.
  279. Returns:
  280. The path of the app root file.
  281. """
  282. return os.path.join(
  283. constants.STYLES_DIR, constants.PageNames.STYLESHEET_ROOT + constants.Ext.CSS
  284. )
  285. def get_context_path() -> str:
  286. """Get the path of the context / initial state file.
  287. Returns:
  288. The path of the context module.
  289. """
  290. return os.path.join(
  291. constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
  292. )
  293. def get_components_path() -> str:
  294. """Get the path of the compiled components.
  295. Returns:
  296. The path of the compiled components.
  297. """
  298. return os.path.join(constants.Dirs.WEB_UTILS, "components" + constants.Ext.JS)
  299. def get_stateful_components_path() -> str:
  300. """Get the path of the compiled stateful components.
  301. Returns:
  302. The path of the compiled stateful components.
  303. """
  304. return os.path.join(
  305. constants.Dirs.WEB_UTILS,
  306. constants.PageNames.STATEFUL_COMPONENTS + constants.Ext.JS,
  307. )
  308. def add_meta(
  309. page: Component,
  310. title: str,
  311. image: str,
  312. meta: list[dict],
  313. description: str | None = None,
  314. ) -> Component:
  315. """Add metadata to a page.
  316. Args:
  317. page: The component for the page.
  318. title: The title of the page.
  319. image: The image for the page.
  320. meta: The metadata list.
  321. description: The description of the page.
  322. Returns:
  323. The component with the metadata added.
  324. """
  325. meta_tags = [Meta.create(**item) for item in meta]
  326. children: list[Any] = [
  327. Title.create(title),
  328. ]
  329. if description:
  330. children.append(Description.create(content=description))
  331. children.append(Image.create(content=image))
  332. page.children.append(
  333. Head.create(
  334. *children,
  335. *meta_tags,
  336. )
  337. )
  338. return page
  339. def write_page(path: str, code: str):
  340. """Write the given code to the given path.
  341. Args:
  342. path: The path to write the code to.
  343. code: The code to write.
  344. """
  345. path_ops.mkdir(os.path.dirname(path))
  346. with open(path, "w", encoding="utf-8") as f:
  347. f.write(code)
  348. def empty_dir(path: str, keep_files: list[str] | None = None):
  349. """Remove all files and folders in a directory except for the keep_files.
  350. Args:
  351. path: The path to the directory that will be emptied
  352. keep_files: List of filenames or foldernames that will not be deleted.
  353. """
  354. # If the directory does not exist, return.
  355. if not os.path.exists(path):
  356. return
  357. # Remove all files and folders in the directory.
  358. keep_files = keep_files or []
  359. directory_contents = os.listdir(path)
  360. for element in directory_contents:
  361. if element not in keep_files:
  362. path_ops.rm(os.path.join(path, element))
  363. def is_valid_url(url) -> bool:
  364. """Check if a url is valid.
  365. Args:
  366. url: The Url to check.
  367. Returns:
  368. Whether url is valid.
  369. """
  370. result = urlparse(url)
  371. return all([result.scheme, result.netloc])