test_compiler.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. import importlib.util
  2. import os
  3. from pathlib import Path
  4. import pytest
  5. from reflex import constants
  6. from reflex.compiler import compiler, utils
  7. from reflex.components.base import document
  8. from reflex.constants.compiler import PageNames
  9. from reflex.utils.imports import ImportVar, ParsedImportDict
  10. from reflex.vars.base import Var
  11. from reflex.vars.sequence import LiteralStringVar
  12. @pytest.mark.parametrize(
  13. "fields,test_default,test_rest",
  14. [
  15. (
  16. [ImportVar(tag="axios", is_default=True)],
  17. "axios",
  18. [],
  19. ),
  20. (
  21. [ImportVar(tag="foo"), ImportVar(tag="bar")],
  22. "",
  23. ["bar", "foo"],
  24. ),
  25. (
  26. [
  27. ImportVar(tag="axios", is_default=True),
  28. ImportVar(tag="foo"),
  29. ImportVar(tag="bar"),
  30. ],
  31. "axios",
  32. ["bar", "foo"],
  33. ),
  34. ],
  35. )
  36. def test_compile_import_statement(
  37. fields: list[ImportVar], test_default: str, test_rest: str
  38. ):
  39. """Test the compile_import_statement function.
  40. Args:
  41. fields: The fields to import.
  42. test_default: The expected output of default library.
  43. test_rest: The expected output rest libraries.
  44. """
  45. default, rest = utils.compile_import_statement(fields)
  46. assert default == test_default
  47. assert sorted(rest) == test_rest
  48. @pytest.mark.parametrize(
  49. "import_dict,test_dicts",
  50. [
  51. ({}, []),
  52. (
  53. {"axios": [ImportVar(tag="axios", is_default=True)]},
  54. [{"lib": "axios", "default": "axios", "rest": []}],
  55. ),
  56. (
  57. {"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]},
  58. [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
  59. ),
  60. (
  61. {
  62. "axios": [
  63. ImportVar(tag="axios", is_default=True),
  64. ImportVar(tag="foo"),
  65. ImportVar(tag="bar"),
  66. ],
  67. "react": [ImportVar(tag="react", is_default=True)],
  68. },
  69. [
  70. {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
  71. {"lib": "react", "default": "react", "rest": []},
  72. ],
  73. ),
  74. (
  75. {"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]},
  76. [
  77. {"lib": "lib1.js", "default": "", "rest": []},
  78. {"lib": "lib2.js", "default": "", "rest": []},
  79. ],
  80. ),
  81. (
  82. {
  83. "": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")],
  84. "axios": [ImportVar(tag="axios", is_default=True)],
  85. },
  86. [
  87. {"lib": "lib1.js", "default": "", "rest": []},
  88. {"lib": "lib2.js", "default": "", "rest": []},
  89. {"lib": "axios", "default": "axios", "rest": []},
  90. ],
  91. ),
  92. ],
  93. )
  94. def test_compile_imports(import_dict: ParsedImportDict, test_dicts: list[dict]):
  95. """Test the compile_imports function.
  96. Args:
  97. import_dict: The import dictionary.
  98. test_dicts: The expected output.
  99. """
  100. imports = utils.compile_imports(import_dict)
  101. for import_dict, test_dict in zip(imports, test_dicts, strict=True):
  102. assert import_dict["lib"] == test_dict["lib"]
  103. assert import_dict["default"] == test_dict["default"]
  104. assert (
  105. sorted(
  106. import_dict["rest"],
  107. key=lambda i: i if isinstance(i, str) else (i.tag or ""),
  108. )
  109. == test_dict["rest"]
  110. )
  111. def test_compile_stylesheets(tmp_path: Path, mocker):
  112. """Test that stylesheets compile correctly.
  113. Args:
  114. tmp_path: The test directory.
  115. mocker: Pytest mocker object.
  116. """
  117. project = tmp_path / "test_project"
  118. project.mkdir()
  119. assets_dir = project / "assets"
  120. assets_dir.mkdir()
  121. (assets_dir / "style.css").write_text(
  122. "button.rt-Button {\n\tborder-radius:unset !important;\n}"
  123. )
  124. mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project)
  125. mocker.patch(
  126. "reflex.compiler.compiler.get_web_dir",
  127. return_value=project / constants.Dirs.WEB,
  128. )
  129. mocker.patch(
  130. "reflex.compiler.utils.get_web_dir", return_value=project / constants.Dirs.WEB
  131. )
  132. stylesheets = [
  133. "https://fonts.googleapis.com/css?family=Sofia&effect=neon|outline|emboss|shadow-multiple",
  134. "https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css",
  135. "/style.css",
  136. "https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap-theme.min.css",
  137. ]
  138. assert compiler.compile_root_stylesheet(stylesheets) == (
  139. str(
  140. project
  141. / constants.Dirs.WEB
  142. / "styles"
  143. / (PageNames.STYLESHEET_ROOT + ".css")
  144. ),
  145. "@import url('./tailwind.css'); \n"
  146. "@import url('https://fonts.googleapis.com/css?family=Sofia&effect=neon|outline|emboss|shadow-multiple'); \n"
  147. "@import url('https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css'); \n"
  148. "@import url('https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap-theme.min.css'); \n"
  149. "@import url('./style.css'); \n",
  150. )
  151. assert (project / constants.Dirs.WEB / "styles" / "style.css").read_text() == (
  152. assets_dir / "style.css"
  153. ).read_text()
  154. def test_compile_stylesheets_scss_sass(tmp_path: Path, mocker):
  155. if importlib.util.find_spec("sass") is None:
  156. pytest.skip(
  157. 'The `libsass` package is required to compile sass/scss stylesheet files. Run `pip install "libsass>=0.23.0"`.'
  158. )
  159. if os.name == "nt":
  160. pytest.skip("Skipping test on Windows")
  161. project = tmp_path / "test_project"
  162. project.mkdir()
  163. assets_dir = project / "assets"
  164. assets_dir.mkdir()
  165. assets_preprocess_dir = assets_dir / "preprocess"
  166. assets_preprocess_dir.mkdir()
  167. (assets_dir / "style.css").write_text(
  168. "button.rt-Button {\n\tborder-radius:unset !important;\n}"
  169. )
  170. (assets_preprocess_dir / "styles_a.sass").write_text(
  171. "button.rt-Button\n\tborder-radius:unset !important"
  172. )
  173. (assets_preprocess_dir / "styles_b.scss").write_text(
  174. "button.rt-Button {\n\tborder-radius:unset !important;\n}"
  175. )
  176. mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project)
  177. mocker.patch(
  178. "reflex.compiler.compiler.get_web_dir",
  179. return_value=project / constants.Dirs.WEB,
  180. )
  181. mocker.patch(
  182. "reflex.compiler.utils.get_web_dir", return_value=project / constants.Dirs.WEB
  183. )
  184. stylesheets = [
  185. "/style.css",
  186. "/preprocess/styles_a.sass",
  187. "/preprocess/styles_b.scss",
  188. ]
  189. assert compiler.compile_root_stylesheet(stylesheets) == (
  190. str(
  191. project
  192. / constants.Dirs.WEB
  193. / "styles"
  194. / (PageNames.STYLESHEET_ROOT + ".css")
  195. ),
  196. "@import url('./tailwind.css'); \n"
  197. "@import url('./style.css'); \n"
  198. f"@import url('./{Path('preprocess') / Path('styles_a.css')!s}'); \n"
  199. f"@import url('./{Path('preprocess') / Path('styles_b.css')!s}'); \n",
  200. )
  201. stylesheets = [
  202. "/style.css",
  203. "/preprocess", # this is a folder containing "styles_a.sass" and "styles_b.scss"
  204. ]
  205. assert compiler.compile_root_stylesheet(stylesheets) == (
  206. str(
  207. project
  208. / constants.Dirs.WEB
  209. / "styles"
  210. / (PageNames.STYLESHEET_ROOT + ".css")
  211. ),
  212. "@import url('./tailwind.css'); \n"
  213. "@import url('./style.css'); \n"
  214. f"@import url('./{Path('preprocess') / Path('styles_a.css')!s}'); \n"
  215. f"@import url('./{Path('preprocess') / Path('styles_b.css')!s}'); \n",
  216. )
  217. assert (project / constants.Dirs.WEB / "styles" / "style.css").read_text() == (
  218. assets_dir / "style.css"
  219. ).read_text()
  220. expected_result = "button.rt-Button{border-radius:unset !important}\n"
  221. assert (
  222. project / constants.Dirs.WEB / "styles" / "preprocess" / "styles_a.css"
  223. ).read_text() == expected_result
  224. assert (
  225. project / constants.Dirs.WEB / "styles" / "preprocess" / "styles_b.css"
  226. ).read_text() == expected_result
  227. def test_compile_stylesheets_exclude_tailwind(tmp_path, mocker):
  228. """Test that Tailwind is excluded if tailwind config is explicitly set to None.
  229. Args:
  230. tmp_path: The test directory.
  231. mocker: Pytest mocker object.
  232. """
  233. project = tmp_path / "test_project"
  234. project.mkdir()
  235. assets_dir = project / "assets"
  236. assets_dir.mkdir()
  237. mock = mocker.Mock()
  238. mocker.patch.object(mock, "tailwind", None)
  239. mocker.patch("reflex.compiler.compiler.get_config", return_value=mock)
  240. (assets_dir / "style.css").touch()
  241. mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project)
  242. stylesheets = [
  243. "/style.css",
  244. ]
  245. assert compiler.compile_root_stylesheet(stylesheets) == (
  246. str(Path(".web") / "styles" / (PageNames.STYLESHEET_ROOT + ".css")),
  247. "@import url('./style.css'); \n",
  248. )
  249. def test_compile_nonexistent_stylesheet(tmp_path, mocker):
  250. """Test that an error is thrown for non-existent stylesheets.
  251. Args:
  252. tmp_path: The test directory.
  253. mocker: Pytest mocker object.
  254. """
  255. project = tmp_path / "test_project"
  256. project.mkdir()
  257. assets_dir = project / "assets"
  258. assets_dir.mkdir()
  259. mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project)
  260. stylesheets = ["/style.css"]
  261. with pytest.raises(FileNotFoundError):
  262. compiler.compile_root_stylesheet(stylesheets)
  263. def test_create_document_root():
  264. """Test that the document root is created correctly."""
  265. # Test with no components.
  266. root = utils.create_document_root()
  267. root.render()
  268. assert isinstance(root, utils.Html)
  269. assert isinstance(root.children[0], utils.Head)
  270. # Default language.
  271. lang = root.lang # pyright: ignore [reportAttributeAccessIssue]
  272. assert isinstance(lang, LiteralStringVar)
  273. assert lang.equals(Var.create("en"))
  274. # No children in head.
  275. assert len(root.children[0].children) == 4
  276. assert isinstance(root.children[0].children[0], utils.Meta)
  277. char_set = root.children[0].children[0].char_set # pyright: ignore [reportAttributeAccessIssue]
  278. assert isinstance(char_set, LiteralStringVar)
  279. assert char_set.equals(Var.create("utf-8"))
  280. assert isinstance(root.children[0].children[1], utils.Meta)
  281. name = root.children[0].children[1].name # pyright: ignore [reportAttributeAccessIssue]
  282. assert isinstance(name, LiteralStringVar)
  283. assert name.equals(Var.create("viewport"))
  284. assert isinstance(root.children[0].children[2], document.Meta)
  285. assert isinstance(root.children[0].children[3], document.Links)
  286. # Test with components.
  287. comps = [
  288. utils.Scripts.create(src="foo.js"),
  289. utils.Scripts.create(src="bar.js"),
  290. ]
  291. root = utils.create_document_root(
  292. head_components=comps,
  293. html_lang="rx",
  294. html_custom_attrs={"project": "reflex"},
  295. )
  296. # Two children in head.
  297. assert isinstance(root, utils.Html)
  298. assert len(root.children[0].children) == 2
  299. lang = root.lang # pyright: ignore [reportAttributeAccessIssue]
  300. assert isinstance(lang, LiteralStringVar)
  301. assert lang.equals(Var.create("rx"))
  302. assert isinstance(root.custom_attrs, dict)
  303. assert root.custom_attrs == {"project": "reflex"}