test_prerequisites.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. import json
  2. import re
  3. import shutil
  4. import tempfile
  5. from pathlib import Path
  6. from unittest.mock import Mock, mock_open
  7. import pytest
  8. from typer.testing import CliRunner
  9. from reflex import constants
  10. from reflex.config import Config
  11. from reflex.reflex import cli
  12. from reflex.testing import chdir
  13. from reflex.utils.prerequisites import (
  14. CpuInfo,
  15. _update_next_config,
  16. cached_procedure,
  17. get_cpu_info,
  18. initialize_requirements_txt,
  19. rename_imports_and_app_name,
  20. )
  21. runner = CliRunner()
  22. @pytest.mark.parametrize(
  23. "config, export, expected_output",
  24. [
  25. (
  26. Config(
  27. app_name="test",
  28. ),
  29. False,
  30. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
  31. ),
  32. (
  33. Config(
  34. app_name="test",
  35. static_page_generation_timeout=30,
  36. ),
  37. False,
  38. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 30};',
  39. ),
  40. (
  41. Config(
  42. app_name="test",
  43. next_compression=False,
  44. ),
  45. False,
  46. 'module.exports = {basePath: "", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60};',
  47. ),
  48. (
  49. Config(
  50. app_name="test",
  51. frontend_path="/test",
  52. ),
  53. False,
  54. 'module.exports = {basePath: "/test", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
  55. ),
  56. (
  57. Config(
  58. app_name="test",
  59. frontend_path="/test",
  60. next_compression=False,
  61. ),
  62. False,
  63. 'module.exports = {basePath: "/test", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60};',
  64. ),
  65. (
  66. Config(
  67. app_name="test",
  68. ),
  69. True,
  70. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, output: "export", distDir: "_static"};',
  71. ),
  72. ],
  73. )
  74. def test_update_next_config(config, export, expected_output):
  75. output = _update_next_config(config, export=export)
  76. assert output == expected_output
  77. @pytest.mark.parametrize(
  78. ("transpile_packages", "expected_transpile_packages"),
  79. (
  80. (
  81. ["foo", "@bar/baz"],
  82. ["@bar/baz", "foo"],
  83. ),
  84. (
  85. ["foo", "@bar/baz", "foo", "@bar/baz@3.2.1"],
  86. ["@bar/baz", "foo"],
  87. ),
  88. ),
  89. )
  90. def test_transpile_packages(transpile_packages, expected_transpile_packages):
  91. output = _update_next_config(
  92. Config(app_name="test"),
  93. transpile_packages=transpile_packages,
  94. )
  95. transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", output)
  96. transpile_packages_json = transpile_packages_match.group(1) # pyright: ignore [reportOptionalMemberAccess]
  97. actual_transpile_packages = sorted(json.loads(transpile_packages_json))
  98. assert actual_transpile_packages == expected_transpile_packages
  99. def test_initialize_requirements_txt_no_op(mocker):
  100. # File exists, reflex is included, do nothing
  101. mocker.patch("pathlib.Path.exists", return_value=True)
  102. mocker.patch(
  103. "charset_normalizer.from_path",
  104. return_value=Mock(best=lambda: Mock(encoding="utf-8")),
  105. )
  106. mock_fp_touch = mocker.patch("pathlib.Path.touch")
  107. open_mock = mock_open(read_data="reflex==0.6.7")
  108. mocker.patch("pathlib.Path.open", open_mock)
  109. initialize_requirements_txt()
  110. assert open_mock.call_count == 1
  111. assert open_mock.call_args.kwargs["encoding"] == "utf-8"
  112. assert open_mock().write.call_count == 0
  113. mock_fp_touch.assert_not_called()
  114. def test_initialize_requirements_txt_missing_reflex(mocker):
  115. # File exists, reflex is not included, add reflex
  116. mocker.patch("pathlib.Path.exists", return_value=True)
  117. mocker.patch(
  118. "charset_normalizer.from_path",
  119. return_value=Mock(best=lambda: Mock(encoding="utf-8")),
  120. )
  121. open_mock = mock_open(read_data="random-package=1.2.3")
  122. mocker.patch("pathlib.Path.open", open_mock)
  123. initialize_requirements_txt()
  124. # Currently open for read, then open for append
  125. assert open_mock.call_count == 2
  126. for call_args in open_mock.call_args_list:
  127. assert call_args.kwargs["encoding"] == "utf-8"
  128. assert (
  129. open_mock().write.call_args[0][0]
  130. == f"\n{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  131. )
  132. def test_initialize_requirements_txt_not_exist(mocker):
  133. # File does not exist, create file with reflex
  134. mocker.patch("pathlib.Path.exists", return_value=False)
  135. open_mock = mock_open()
  136. mocker.patch("pathlib.Path.open", open_mock)
  137. initialize_requirements_txt()
  138. assert open_mock.call_count == 2
  139. # By default, use utf-8 encoding
  140. for call_args in open_mock.call_args_list:
  141. assert call_args.kwargs["encoding"] == "utf-8"
  142. assert open_mock().write.call_count == 1
  143. assert (
  144. open_mock().write.call_args[0][0]
  145. == f"{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  146. )
  147. def test_requirements_txt_cannot_detect_encoding(mocker):
  148. mocker.patch("pathlib.Path.exists", return_value=True)
  149. mock_open = mocker.patch("builtins.open")
  150. mocker.patch(
  151. "charset_normalizer.from_path",
  152. return_value=Mock(best=lambda: None),
  153. )
  154. initialize_requirements_txt()
  155. mock_open.assert_not_called()
  156. def test_requirements_txt_other_encoding(mocker):
  157. mocker.patch("pathlib.Path.exists", return_value=True)
  158. mocker.patch(
  159. "charset_normalizer.from_path",
  160. return_value=Mock(best=lambda: Mock(encoding="utf-16")),
  161. )
  162. initialize_requirements_txt()
  163. open_mock = mock_open(read_data="random-package=1.2.3")
  164. mocker.patch("pathlib.Path.open", open_mock)
  165. initialize_requirements_txt()
  166. # Currently open for read, then open for append
  167. assert open_mock.call_count == 2
  168. for call_args in open_mock.call_args_list:
  169. assert call_args.kwargs["encoding"] == "utf-16"
  170. assert (
  171. open_mock().write.call_args[0][0]
  172. == f"\n{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  173. )
  174. def test_cached_procedure():
  175. call_count = 0
  176. @cached_procedure(tempfile.mktemp(), payload_fn=lambda: "constant")
  177. def _function_with_no_args():
  178. nonlocal call_count
  179. call_count += 1
  180. _function_with_no_args()
  181. assert call_count == 1
  182. _function_with_no_args()
  183. assert call_count == 1
  184. call_count = 0
  185. @cached_procedure(
  186. cache_file=tempfile.mktemp(),
  187. payload_fn=lambda *args, **kwargs: f"{repr(args), repr(kwargs)}",
  188. )
  189. def _function_with_some_args(*args, **kwargs):
  190. nonlocal call_count
  191. call_count += 1
  192. _function_with_some_args(1, y=2)
  193. assert call_count == 1
  194. _function_with_some_args(1, y=2)
  195. assert call_count == 1
  196. _function_with_some_args(100, y=300)
  197. assert call_count == 2
  198. _function_with_some_args(100, y=300)
  199. assert call_count == 2
  200. call_count = 0
  201. @cached_procedure(
  202. cache_file=None, cache_file_fn=tempfile.mktemp, payload_fn=lambda: "constant"
  203. )
  204. def _function_with_no_args_fn():
  205. nonlocal call_count
  206. call_count += 1
  207. _function_with_no_args_fn()
  208. assert call_count == 1
  209. _function_with_no_args_fn()
  210. assert call_count == 2
  211. def test_get_cpu_info():
  212. cpu_info = get_cpu_info()
  213. assert cpu_info is not None
  214. assert isinstance(cpu_info, CpuInfo)
  215. assert cpu_info.model_name is not None
  216. for attr in ("manufacturer_id", "model_name", "address_width"):
  217. value = getattr(cpu_info, attr)
  218. assert value.strip() if attr != "address_width" else value
  219. @pytest.fixture
  220. def temp_directory():
  221. temp_dir = tempfile.mkdtemp()
  222. yield Path(temp_dir)
  223. shutil.rmtree(temp_dir)
  224. @pytest.mark.parametrize(
  225. "config_code,expected",
  226. [
  227. ("rx.Config(app_name='old_name')", 'rx.Config(app_name="new_name")'),
  228. ('rx.Config(app_name="old_name")', 'rx.Config(app_name="new_name")'),
  229. ("rx.Config('old_name')", 'rx.Config("new_name")'),
  230. ('rx.Config("old_name")', 'rx.Config("new_name")'),
  231. ],
  232. )
  233. def test_rename_imports_and_app_name(temp_directory, config_code, expected):
  234. file_path = temp_directory / "rxconfig.py"
  235. content = f"""
  236. config = {config_code}
  237. """
  238. file_path.write_text(content)
  239. rename_imports_and_app_name(file_path, "old_name", "new_name")
  240. updated_content = file_path.read_text()
  241. expected_content = f"""
  242. config = {expected}
  243. """
  244. assert updated_content == expected_content
  245. def test_regex_edge_cases(temp_directory):
  246. file_path = temp_directory / "example.py"
  247. content = """
  248. from old_name.module import something
  249. import old_name
  250. from old_name import something_else as alias
  251. from old_name
  252. """
  253. file_path.write_text(content)
  254. rename_imports_and_app_name(file_path, "old_name", "new_name")
  255. updated_content = file_path.read_text()
  256. expected_content = """
  257. from new_name.module import something
  258. import new_name
  259. from new_name import something_else as alias
  260. from new_name
  261. """
  262. assert updated_content == expected_content
  263. def test_cli_rename_command(temp_directory):
  264. foo_dir = temp_directory / "foo"
  265. foo_dir.mkdir()
  266. (foo_dir / "__init__").touch()
  267. (foo_dir / ".web").mkdir()
  268. (foo_dir / "assets").mkdir()
  269. (foo_dir / "foo").mkdir()
  270. (foo_dir / "foo" / "__init__.py").touch()
  271. (foo_dir / "rxconfig.py").touch()
  272. (foo_dir / "rxconfig.py").write_text(
  273. """
  274. import reflex as rx
  275. config = rx.Config(
  276. app_name="foo",
  277. )
  278. """
  279. )
  280. (foo_dir / "foo" / "components").mkdir()
  281. (foo_dir / "foo" / "components" / "__init__.py").touch()
  282. (foo_dir / "foo" / "components" / "base.py").touch()
  283. (foo_dir / "foo" / "components" / "views.py").touch()
  284. (foo_dir / "foo" / "components" / "base.py").write_text(
  285. """
  286. import reflex as rx
  287. from foo.components import views
  288. from foo.components.views import *
  289. from .base import *
  290. def random_component():
  291. return rx.fragment()
  292. """
  293. )
  294. (foo_dir / "foo" / "foo.py").touch()
  295. (foo_dir / "foo" / "foo.py").write_text(
  296. """
  297. import reflex as rx
  298. import foo.components.base
  299. from foo.components.base import random_component
  300. class State(rx.State):
  301. pass
  302. def index():
  303. return rx.text("Hello, World!")
  304. app = rx.App()
  305. app.add_page(index)
  306. """
  307. )
  308. with chdir(temp_directory / "foo"):
  309. result = runner.invoke(cli, ["rename", "bar"])
  310. assert result.exit_code == 0
  311. assert (foo_dir / "rxconfig.py").read_text() == (
  312. """
  313. import reflex as rx
  314. config = rx.Config(
  315. app_name="bar",
  316. )
  317. """
  318. )
  319. assert (foo_dir / "bar").exists()
  320. assert not (foo_dir / "foo").exists()
  321. assert (foo_dir / "bar" / "components" / "base.py").read_text() == (
  322. """
  323. import reflex as rx
  324. from bar.components import views
  325. from bar.components.views import *
  326. from .base import *
  327. def random_component():
  328. return rx.fragment()
  329. """
  330. )
  331. assert (foo_dir / "bar" / "bar.py").exists()
  332. assert not (foo_dir / "bar" / "foo.py").exists()
  333. assert (foo_dir / "bar" / "bar.py").read_text() == (
  334. """
  335. import reflex as rx
  336. import bar.components.base
  337. from bar.components.base import random_component
  338. class State(rx.State):
  339. pass
  340. def index():
  341. return rx.text("Hello, World!")
  342. app = rx.App()
  343. app.add_page(index)
  344. """
  345. )