test_prerequisites.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. import json
  2. import re
  3. import shutil
  4. import tempfile
  5. from pathlib import Path
  6. from unittest.mock import Mock, mock_open
  7. import pytest
  8. from typer.testing import CliRunner
  9. from reflex import constants
  10. from reflex.config import Config
  11. from reflex.reflex import cli
  12. from reflex.testing import chdir
  13. from reflex.utils.prerequisites import (
  14. CpuInfo,
  15. _update_next_config,
  16. cached_procedure,
  17. get_cpu_info,
  18. initialize_requirements_txt,
  19. rename_imports_and_app_name,
  20. )
  21. runner = CliRunner()
  22. @pytest.mark.parametrize(
  23. "config, export, expected_output",
  24. [
  25. (
  26. Config(
  27. app_name="test",
  28. ),
  29. False,
  30. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  31. ),
  32. (
  33. Config(
  34. app_name="test",
  35. static_page_generation_timeout=30,
  36. ),
  37. False,
  38. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 30, devIndicators: false};',
  39. ),
  40. (
  41. Config(
  42. app_name="test",
  43. next_compression=False,
  44. ),
  45. False,
  46. 'module.exports = {basePath: "", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  47. ),
  48. (
  49. Config(
  50. app_name="test",
  51. frontend_path="/test",
  52. ),
  53. False,
  54. 'module.exports = {basePath: "/test", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  55. ),
  56. (
  57. Config(
  58. app_name="test",
  59. frontend_path="/test",
  60. next_compression=False,
  61. ),
  62. False,
  63. 'module.exports = {basePath: "/test", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  64. ),
  65. (
  66. Config(
  67. app_name="test",
  68. ),
  69. True,
  70. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false, output: "export", distDir: "_static"};',
  71. ),
  72. (
  73. Config(
  74. app_name="test",
  75. next_dev_indicators=True,
  76. ),
  77. True,
  78. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: true, output: "export", distDir: "_static"};',
  79. ),
  80. ],
  81. )
  82. def test_update_next_config(config, export, expected_output):
  83. output = _update_next_config(config, export=export)
  84. assert output == expected_output
  85. @pytest.mark.parametrize(
  86. ("transpile_packages", "expected_transpile_packages"),
  87. (
  88. (
  89. ["foo", "@bar/baz"],
  90. ["@bar/baz", "foo"],
  91. ),
  92. (
  93. ["foo", "@bar/baz", "foo", "@bar/baz@3.2.1"],
  94. ["@bar/baz", "foo"],
  95. ),
  96. ),
  97. )
  98. def test_transpile_packages(transpile_packages, expected_transpile_packages):
  99. output = _update_next_config(
  100. Config(app_name="test"),
  101. transpile_packages=transpile_packages,
  102. )
  103. transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", output)
  104. transpile_packages_json = transpile_packages_match.group(1) # pyright: ignore [reportOptionalMemberAccess]
  105. actual_transpile_packages = sorted(json.loads(transpile_packages_json))
  106. assert actual_transpile_packages == expected_transpile_packages
  107. def test_initialize_requirements_txt_no_op(mocker):
  108. # File exists, reflex is included, do nothing
  109. mocker.patch("pathlib.Path.exists", return_value=True)
  110. mocker.patch(
  111. "charset_normalizer.from_path",
  112. return_value=Mock(best=lambda: Mock(encoding="utf-8")),
  113. )
  114. mock_fp_touch = mocker.patch("pathlib.Path.touch")
  115. open_mock = mock_open(read_data="reflex==0.6.7")
  116. mocker.patch("pathlib.Path.open", open_mock)
  117. initialize_requirements_txt()
  118. assert open_mock.call_count == 1
  119. assert open_mock.call_args.kwargs["encoding"] == "utf-8"
  120. assert open_mock().write.call_count == 0
  121. mock_fp_touch.assert_not_called()
  122. def test_initialize_requirements_txt_missing_reflex(mocker):
  123. # File exists, reflex is not included, add reflex
  124. mocker.patch("pathlib.Path.exists", return_value=True)
  125. mocker.patch(
  126. "charset_normalizer.from_path",
  127. return_value=Mock(best=lambda: Mock(encoding="utf-8")),
  128. )
  129. open_mock = mock_open(read_data="random-package=1.2.3")
  130. mocker.patch("pathlib.Path.open", open_mock)
  131. initialize_requirements_txt()
  132. # Currently open for read, then open for append
  133. assert open_mock.call_count == 2
  134. for call_args in open_mock.call_args_list:
  135. assert call_args.kwargs["encoding"] == "utf-8"
  136. assert (
  137. open_mock().write.call_args[0][0]
  138. == f"\n{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  139. )
  140. def test_initialize_requirements_txt_not_exist(mocker):
  141. # File does not exist, create file with reflex
  142. mocker.patch("pathlib.Path.exists", return_value=False)
  143. open_mock = mock_open()
  144. mocker.patch("pathlib.Path.open", open_mock)
  145. initialize_requirements_txt()
  146. assert open_mock.call_count == 2
  147. # By default, use utf-8 encoding
  148. for call_args in open_mock.call_args_list:
  149. assert call_args.kwargs["encoding"] == "utf-8"
  150. assert open_mock().write.call_count == 1
  151. assert (
  152. open_mock().write.call_args[0][0]
  153. == f"{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  154. )
  155. def test_requirements_txt_cannot_detect_encoding(mocker):
  156. mocker.patch("pathlib.Path.exists", return_value=True)
  157. mock_open = mocker.patch("builtins.open")
  158. mocker.patch(
  159. "charset_normalizer.from_path",
  160. return_value=Mock(best=lambda: None),
  161. )
  162. initialize_requirements_txt()
  163. mock_open.assert_not_called()
  164. def test_requirements_txt_other_encoding(mocker):
  165. mocker.patch("pathlib.Path.exists", return_value=True)
  166. mocker.patch(
  167. "charset_normalizer.from_path",
  168. return_value=Mock(best=lambda: Mock(encoding="utf-16")),
  169. )
  170. initialize_requirements_txt()
  171. open_mock = mock_open(read_data="random-package=1.2.3")
  172. mocker.patch("pathlib.Path.open", open_mock)
  173. initialize_requirements_txt()
  174. # Currently open for read, then open for append
  175. assert open_mock.call_count == 2
  176. for call_args in open_mock.call_args_list:
  177. assert call_args.kwargs["encoding"] == "utf-16"
  178. assert (
  179. open_mock().write.call_args[0][0]
  180. == f"\n{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  181. )
  182. def test_cached_procedure():
  183. call_count = 0
  184. @cached_procedure(tempfile.mktemp(), payload_fn=lambda: "constant")
  185. def _function_with_no_args():
  186. nonlocal call_count
  187. call_count += 1
  188. _function_with_no_args()
  189. assert call_count == 1
  190. _function_with_no_args()
  191. assert call_count == 1
  192. call_count = 0
  193. @cached_procedure(
  194. cache_file=tempfile.mktemp(),
  195. payload_fn=lambda *args, **kwargs: f"{repr(args), repr(kwargs)}",
  196. )
  197. def _function_with_some_args(*args, **kwargs):
  198. nonlocal call_count
  199. call_count += 1
  200. _function_with_some_args(1, y=2)
  201. assert call_count == 1
  202. _function_with_some_args(1, y=2)
  203. assert call_count == 1
  204. _function_with_some_args(100, y=300)
  205. assert call_count == 2
  206. _function_with_some_args(100, y=300)
  207. assert call_count == 2
  208. call_count = 0
  209. @cached_procedure(
  210. cache_file=None, cache_file_fn=tempfile.mktemp, payload_fn=lambda: "constant"
  211. )
  212. def _function_with_no_args_fn():
  213. nonlocal call_count
  214. call_count += 1
  215. _function_with_no_args_fn()
  216. assert call_count == 1
  217. _function_with_no_args_fn()
  218. assert call_count == 2
  219. def test_get_cpu_info():
  220. cpu_info = get_cpu_info()
  221. assert cpu_info is not None
  222. assert isinstance(cpu_info, CpuInfo)
  223. assert cpu_info.model_name is not None
  224. for attr in ("manufacturer_id", "model_name", "address_width"):
  225. value = getattr(cpu_info, attr)
  226. assert value.strip() if attr != "address_width" else value
  227. @pytest.fixture
  228. def temp_directory():
  229. temp_dir = tempfile.mkdtemp()
  230. yield Path(temp_dir)
  231. shutil.rmtree(temp_dir)
  232. @pytest.mark.parametrize(
  233. "config_code,expected",
  234. [
  235. ("rx.Config(app_name='old_name')", 'rx.Config(app_name="new_name")'),
  236. ('rx.Config(app_name="old_name")', 'rx.Config(app_name="new_name")'),
  237. ("rx.Config('old_name')", 'rx.Config("new_name")'),
  238. ('rx.Config("old_name")', 'rx.Config("new_name")'),
  239. ],
  240. )
  241. def test_rename_imports_and_app_name(temp_directory, config_code, expected):
  242. file_path = temp_directory / "rxconfig.py"
  243. content = f"""
  244. config = {config_code}
  245. """
  246. file_path.write_text(content)
  247. rename_imports_and_app_name(file_path, "old_name", "new_name")
  248. updated_content = file_path.read_text()
  249. expected_content = f"""
  250. config = {expected}
  251. """
  252. assert updated_content == expected_content
  253. def test_regex_edge_cases(temp_directory):
  254. file_path = temp_directory / "example.py"
  255. content = """
  256. from old_name.module import something
  257. import old_name
  258. from old_name import something_else as alias
  259. from old_name
  260. """
  261. file_path.write_text(content)
  262. rename_imports_and_app_name(file_path, "old_name", "new_name")
  263. updated_content = file_path.read_text()
  264. expected_content = """
  265. from new_name.module import something
  266. import new_name
  267. from new_name import something_else as alias
  268. from new_name
  269. """
  270. assert updated_content == expected_content
  271. def test_cli_rename_command(temp_directory):
  272. foo_dir = temp_directory / "foo"
  273. foo_dir.mkdir()
  274. (foo_dir / "__init__").touch()
  275. (foo_dir / ".web").mkdir()
  276. (foo_dir / "assets").mkdir()
  277. (foo_dir / "foo").mkdir()
  278. (foo_dir / "foo" / "__init__.py").touch()
  279. (foo_dir / "rxconfig.py").touch()
  280. (foo_dir / "rxconfig.py").write_text(
  281. """
  282. import reflex as rx
  283. config = rx.Config(
  284. app_name="foo",
  285. )
  286. """
  287. )
  288. (foo_dir / "foo" / "components").mkdir()
  289. (foo_dir / "foo" / "components" / "__init__.py").touch()
  290. (foo_dir / "foo" / "components" / "base.py").touch()
  291. (foo_dir / "foo" / "components" / "views.py").touch()
  292. (foo_dir / "foo" / "components" / "base.py").write_text(
  293. """
  294. import reflex as rx
  295. from foo.components import views
  296. from foo.components.views import *
  297. from .base import *
  298. def random_component():
  299. return rx.fragment()
  300. """
  301. )
  302. (foo_dir / "foo" / "foo.py").touch()
  303. (foo_dir / "foo" / "foo.py").write_text(
  304. """
  305. import reflex as rx
  306. import foo.components.base
  307. from foo.components.base import random_component
  308. class State(rx.State):
  309. pass
  310. def index():
  311. return rx.text("Hello, World!")
  312. app = rx.App()
  313. app.add_page(index)
  314. """
  315. )
  316. with chdir(temp_directory / "foo"):
  317. result = runner.invoke(cli, ["rename", "bar"])
  318. assert result.exit_code == 0
  319. assert (foo_dir / "rxconfig.py").read_text() == (
  320. """
  321. import reflex as rx
  322. config = rx.Config(
  323. app_name="bar",
  324. )
  325. """
  326. )
  327. assert (foo_dir / "bar").exists()
  328. assert not (foo_dir / "foo").exists()
  329. assert (foo_dir / "bar" / "components" / "base.py").read_text() == (
  330. """
  331. import reflex as rx
  332. from bar.components import views
  333. from bar.components.views import *
  334. from .base import *
  335. def random_component():
  336. return rx.fragment()
  337. """
  338. )
  339. assert (foo_dir / "bar" / "bar.py").exists()
  340. assert not (foo_dir / "bar" / "foo.py").exists()
  341. assert (foo_dir / "bar" / "bar.py").read_text() == (
  342. """
  343. import reflex as rx
  344. import bar.components.base
  345. from bar.components.base import random_component
  346. class State(rx.State):
  347. pass
  348. def index():
  349. return rx.text("Hello, World!")
  350. app = rx.App()
  351. app.add_page(index)
  352. """
  353. )