test_compiler.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. import importlib.util
  2. import os
  3. from pathlib import Path
  4. import pytest
  5. from reflex import constants
  6. from reflex.compiler import compiler, utils
  7. from reflex.constants.compiler import PageNames
  8. from reflex.utils.imports import ImportVar, ParsedImportDict
  9. @pytest.mark.parametrize(
  10. "fields,test_default,test_rest",
  11. [
  12. (
  13. [ImportVar(tag="axios", is_default=True)],
  14. "axios",
  15. [],
  16. ),
  17. (
  18. [ImportVar(tag="foo"), ImportVar(tag="bar")],
  19. "",
  20. ["bar", "foo"],
  21. ),
  22. (
  23. [
  24. ImportVar(tag="axios", is_default=True),
  25. ImportVar(tag="foo"),
  26. ImportVar(tag="bar"),
  27. ],
  28. "axios",
  29. ["bar", "foo"],
  30. ),
  31. ],
  32. )
  33. def test_compile_import_statement(
  34. fields: list[ImportVar], test_default: str, test_rest: str
  35. ):
  36. """Test the compile_import_statement function.
  37. Args:
  38. fields: The fields to import.
  39. test_default: The expected output of default library.
  40. test_rest: The expected output rest libraries.
  41. """
  42. default, rest = utils.compile_import_statement(fields)
  43. assert default == test_default
  44. assert sorted(rest) == test_rest
  45. @pytest.mark.parametrize(
  46. "import_dict,test_dicts",
  47. [
  48. ({}, []),
  49. (
  50. {"axios": [ImportVar(tag="axios", is_default=True)]},
  51. [{"lib": "axios", "default": "axios", "rest": []}],
  52. ),
  53. (
  54. {"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]},
  55. [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
  56. ),
  57. (
  58. {
  59. "axios": [
  60. ImportVar(tag="axios", is_default=True),
  61. ImportVar(tag="foo"),
  62. ImportVar(tag="bar"),
  63. ],
  64. "react": [ImportVar(tag="react", is_default=True)],
  65. },
  66. [
  67. {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
  68. {"lib": "react", "default": "react", "rest": []},
  69. ],
  70. ),
  71. (
  72. {"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]},
  73. [
  74. {"lib": "lib1.js", "default": "", "rest": []},
  75. {"lib": "lib2.js", "default": "", "rest": []},
  76. ],
  77. ),
  78. (
  79. {
  80. "": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")],
  81. "axios": [ImportVar(tag="axios", is_default=True)],
  82. },
  83. [
  84. {"lib": "lib1.js", "default": "", "rest": []},
  85. {"lib": "lib2.js", "default": "", "rest": []},
  86. {"lib": "axios", "default": "axios", "rest": []},
  87. ],
  88. ),
  89. ],
  90. )
  91. def test_compile_imports(import_dict: ParsedImportDict, test_dicts: list[dict]):
  92. """Test the compile_imports function.
  93. Args:
  94. import_dict: The import dictionary.
  95. test_dicts: The expected output.
  96. """
  97. imports = utils.compile_imports(import_dict)
  98. for import_dict, test_dict in zip(imports, test_dicts, strict=True):
  99. assert import_dict["lib"] == test_dict["lib"]
  100. assert import_dict["default"] == test_dict["default"]
  101. assert (
  102. sorted(
  103. import_dict["rest"],
  104. key=lambda i: i if isinstance(i, str) else (i.tag or ""),
  105. )
  106. == test_dict["rest"]
  107. )
  108. def test_compile_stylesheets(tmp_path: Path, mocker):
  109. """Test that stylesheets compile correctly.
  110. Args:
  111. tmp_path: The test directory.
  112. mocker: Pytest mocker object.
  113. """
  114. project = tmp_path / "test_project"
  115. project.mkdir()
  116. assets_dir = project / "assets"
  117. assets_dir.mkdir()
  118. (assets_dir / "style.css").write_text(
  119. "button.rt-Button {\n\tborder-radius:unset !important;\n}"
  120. )
  121. mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project)
  122. mocker.patch(
  123. "reflex.compiler.compiler.get_web_dir",
  124. return_value=project / constants.Dirs.WEB,
  125. )
  126. mocker.patch(
  127. "reflex.compiler.utils.get_web_dir", return_value=project / constants.Dirs.WEB
  128. )
  129. stylesheets = [
  130. "https://fonts.googleapis.com/css?family=Sofia&effect=neon|outline|emboss|shadow-multiple",
  131. "https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css",
  132. "/style.css",
  133. "https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap-theme.min.css",
  134. ]
  135. assert compiler.compile_root_stylesheet(stylesheets) == (
  136. str(
  137. project
  138. / constants.Dirs.WEB
  139. / "styles"
  140. / (PageNames.STYLESHEET_ROOT + ".css")
  141. ),
  142. "@import url('./tailwind.css'); \n"
  143. "@import url('https://fonts.googleapis.com/css?family=Sofia&effect=neon|outline|emboss|shadow-multiple'); \n"
  144. "@import url('https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css'); \n"
  145. "@import url('https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap-theme.min.css'); \n"
  146. "@import url('./style.css'); \n",
  147. )
  148. assert (project / constants.Dirs.WEB / "styles" / "style.css").read_text() == (
  149. assets_dir / "style.css"
  150. ).read_text()
  151. def test_compile_stylesheets_scss_sass(tmp_path: Path, mocker):
  152. if importlib.util.find_spec("sass") is None:
  153. pytest.skip(
  154. 'The `libsass` package is required to compile sass/scss stylesheet files. Run `pip install "libsass>=0.23.0"`.'
  155. )
  156. if os.name == "nt":
  157. pytest.skip("Skipping test on Windows")
  158. project = tmp_path / "test_project"
  159. project.mkdir()
  160. assets_dir = project / "assets"
  161. assets_dir.mkdir()
  162. assets_preprocess_dir = assets_dir / "preprocess"
  163. assets_preprocess_dir.mkdir()
  164. (assets_dir / "style.css").write_text(
  165. "button.rt-Button {\n\tborder-radius:unset !important;\n}"
  166. )
  167. (assets_preprocess_dir / "styles_a.sass").write_text(
  168. "button.rt-Button\n\tborder-radius:unset !important"
  169. )
  170. (assets_preprocess_dir / "styles_b.scss").write_text(
  171. "button.rt-Button {\n\tborder-radius:unset !important;\n}"
  172. )
  173. mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project)
  174. mocker.patch(
  175. "reflex.compiler.compiler.get_web_dir",
  176. return_value=project / constants.Dirs.WEB,
  177. )
  178. mocker.patch(
  179. "reflex.compiler.utils.get_web_dir", return_value=project / constants.Dirs.WEB
  180. )
  181. stylesheets = [
  182. "/style.css",
  183. "/preprocess/styles_a.sass",
  184. "/preprocess/styles_b.scss",
  185. ]
  186. assert compiler.compile_root_stylesheet(stylesheets) == (
  187. str(
  188. project
  189. / constants.Dirs.WEB
  190. / "styles"
  191. / (PageNames.STYLESHEET_ROOT + ".css")
  192. ),
  193. "@import url('./tailwind.css'); \n"
  194. "@import url('./style.css'); \n"
  195. f"@import url('./{Path('preprocess') / Path('styles_a.css')!s}'); \n"
  196. f"@import url('./{Path('preprocess') / Path('styles_b.css')!s}'); \n",
  197. )
  198. stylesheets = [
  199. "/style.css",
  200. "/preprocess", # this is a folder containing "styles_a.sass" and "styles_b.scss"
  201. ]
  202. assert compiler.compile_root_stylesheet(stylesheets) == (
  203. str(
  204. project
  205. / constants.Dirs.WEB
  206. / "styles"
  207. / (PageNames.STYLESHEET_ROOT + ".css")
  208. ),
  209. "@import url('./tailwind.css'); \n"
  210. "@import url('./style.css'); \n"
  211. f"@import url('./{Path('preprocess') / Path('styles_a.css')!s}'); \n"
  212. f"@import url('./{Path('preprocess') / Path('styles_b.css')!s}'); \n",
  213. )
  214. assert (project / constants.Dirs.WEB / "styles" / "style.css").read_text() == (
  215. assets_dir / "style.css"
  216. ).read_text()
  217. expected_result = "button.rt-Button{border-radius:unset !important}\n"
  218. assert (
  219. project / constants.Dirs.WEB / "styles" / "preprocess" / "styles_a.css"
  220. ).read_text() == expected_result
  221. assert (
  222. project / constants.Dirs.WEB / "styles" / "preprocess" / "styles_b.css"
  223. ).read_text() == expected_result
  224. def test_compile_stylesheets_exclude_tailwind(tmp_path, mocker):
  225. """Test that Tailwind is excluded if tailwind config is explicitly set to None.
  226. Args:
  227. tmp_path: The test directory.
  228. mocker: Pytest mocker object.
  229. """
  230. project = tmp_path / "test_project"
  231. project.mkdir()
  232. assets_dir = project / "assets"
  233. assets_dir.mkdir()
  234. mock = mocker.Mock()
  235. mocker.patch.object(mock, "tailwind", None)
  236. mocker.patch("reflex.compiler.compiler.get_config", return_value=mock)
  237. (assets_dir / "style.css").touch()
  238. mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project)
  239. stylesheets = [
  240. "/style.css",
  241. ]
  242. assert compiler.compile_root_stylesheet(stylesheets) == (
  243. str(Path(".web") / "styles" / (PageNames.STYLESHEET_ROOT + ".css")),
  244. "@import url('./style.css'); \n",
  245. )
  246. def test_compile_nonexistent_stylesheet(tmp_path, mocker):
  247. """Test that an error is thrown for non-existent stylesheets.
  248. Args:
  249. tmp_path: The test directory.
  250. mocker: Pytest mocker object.
  251. """
  252. project = tmp_path / "test_project"
  253. project.mkdir()
  254. assets_dir = project / "assets"
  255. assets_dir.mkdir()
  256. mocker.patch("reflex.compiler.compiler.Path.cwd", return_value=project)
  257. stylesheets = ["/style.css"]
  258. with pytest.raises(FileNotFoundError):
  259. compiler.compile_root_stylesheet(stylesheets)
  260. def test_create_document_root():
  261. """Test that the document root is created correctly."""
  262. # Test with no components.
  263. root = utils.create_document_root()
  264. root.render()
  265. assert isinstance(root, utils.Html)
  266. assert isinstance(root.children[0], utils.DocumentHead)
  267. # Default language.
  268. assert root.lang == "en" # pyright: ignore [reportAttributeAccessIssue]
  269. # No children in head.
  270. assert len(root.children[0].children) == 0
  271. # Test with components.
  272. comps = [
  273. utils.NextScript.create(src="foo.js"),
  274. utils.NextScript.create(src="bar.js"),
  275. ]
  276. root = utils.create_document_root(
  277. head_components=comps,
  278. html_lang="rx",
  279. html_custom_attrs={"project": "reflex"},
  280. )
  281. # Two children in head.
  282. assert isinstance(root, utils.Html)
  283. assert len(root.children[0].children) == 2
  284. assert root.lang == "rx" # pyright: ignore [reportAttributeAccessIssue]
  285. assert isinstance(root.custom_attrs, dict)
  286. assert root.custom_attrs == {"project": "reflex"}