test_prerequisites.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. import json
  2. import re
  3. import shutil
  4. import tempfile
  5. from pathlib import Path
  6. from unittest.mock import Mock, mock_open
  7. import pytest
  8. from typer.testing import CliRunner
  9. from reflex import constants
  10. from reflex.config import Config
  11. from reflex.reflex import cli
  12. from reflex.testing import chdir
  13. from reflex.utils.prerequisites import (
  14. CpuInfo,
  15. _update_next_config,
  16. cached_procedure,
  17. get_cpu_info,
  18. initialize_requirements_txt,
  19. rename_imports_and_app_name,
  20. )
  21. runner = CliRunner()
  22. @pytest.mark.parametrize(
  23. "config, export, expected_output",
  24. [
  25. (
  26. Config(
  27. app_name="test",
  28. ),
  29. False,
  30. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
  31. ),
  32. (
  33. Config(
  34. app_name="test",
  35. static_page_generation_timeout=30,
  36. ),
  37. False,
  38. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 30};',
  39. ),
  40. (
  41. Config(
  42. app_name="test",
  43. next_compression=False,
  44. ),
  45. False,
  46. 'module.exports = {basePath: "", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60};',
  47. ),
  48. (
  49. Config(
  50. app_name="test",
  51. frontend_path="/test",
  52. ),
  53. False,
  54. 'module.exports = {basePath: "/test", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
  55. ),
  56. (
  57. Config(
  58. app_name="test",
  59. frontend_path="/test",
  60. next_compression=False,
  61. ),
  62. False,
  63. 'module.exports = {basePath: "/test", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60};',
  64. ),
  65. (
  66. Config(
  67. app_name="test",
  68. ),
  69. True,
  70. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, output: "export", distDir: "_static"};',
  71. ),
  72. ],
  73. )
  74. def test_update_next_config(config, export, expected_output):
  75. output = _update_next_config(config, export=export)
  76. assert output == expected_output
  77. @pytest.mark.parametrize(
  78. ("transpile_packages", "expected_transpile_packages"),
  79. (
  80. (
  81. ["foo", "@bar/baz"],
  82. ["@bar/baz", "foo"],
  83. ),
  84. (
  85. ["foo", "@bar/baz", "foo", "@bar/baz@3.2.1"],
  86. ["@bar/baz", "foo"],
  87. ),
  88. ),
  89. )
  90. def test_transpile_packages(transpile_packages, expected_transpile_packages):
  91. output = _update_next_config(
  92. Config(app_name="test"),
  93. transpile_packages=transpile_packages,
  94. )
  95. transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", output)
  96. transpile_packages_json = transpile_packages_match.group(1) # pyright: ignore [reportOptionalMemberAccess]
  97. actual_transpile_packages = sorted(json.loads(transpile_packages_json))
  98. assert actual_transpile_packages == expected_transpile_packages
  99. def test_initialize_requirements_txt_no_op(mocker):
  100. # File exists, reflex is included, do nothing
  101. mocker.patch("pathlib.Path.exists", return_value=True)
  102. mocker.patch(
  103. "charset_normalizer.from_path",
  104. return_value=Mock(best=lambda: Mock(encoding="utf-8")),
  105. )
  106. mock_fp_touch = mocker.patch("pathlib.Path.touch")
  107. open_mock = mock_open(read_data="reflex==0.6.7")
  108. mocker.patch("pathlib.Path.open", open_mock)
  109. initialize_requirements_txt()
  110. assert open_mock.call_count == 1
  111. assert open_mock.call_args.kwargs["encoding"] == "utf-8"
  112. assert open_mock().write.call_count == 0
  113. mock_fp_touch.assert_not_called()
  114. def test_initialize_requirements_txt_missing_reflex(mocker):
  115. # File exists, reflex is not included, add reflex
  116. mocker.patch("pathlib.Path.exists", return_value=True)
  117. mocker.patch(
  118. "charset_normalizer.from_path",
  119. return_value=Mock(best=lambda: Mock(encoding="utf-8")),
  120. )
  121. open_mock = mock_open(read_data="random-package=1.2.3")
  122. mocker.patch("pathlib.Path.open", open_mock)
  123. initialize_requirements_txt()
  124. # Currently open for read, then open for append
  125. assert open_mock.call_count == 2
  126. for call_args in open_mock.call_args_list:
  127. assert call_args.kwargs["encoding"] == "utf-8"
  128. assert (
  129. open_mock().write.call_args[0][0]
  130. == f"\n{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  131. )
  132. def test_initialize_requirements_txt_not_exist(mocker):
  133. # File does not exist, create file with reflex
  134. mocker.patch("pathlib.Path.exists", return_value=False)
  135. open_mock = mock_open()
  136. mocker.patch("pathlib.Path.open", open_mock)
  137. initialize_requirements_txt()
  138. assert open_mock.call_count == 2
  139. # By default, use utf-8 encoding
  140. for call_args in open_mock.call_args_list:
  141. assert call_args.kwargs["encoding"] == "utf-8"
  142. assert open_mock().write.call_count == 1
  143. assert (
  144. open_mock().write.call_args[0][0]
  145. == f"{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  146. )
  147. def test_requirements_txt_cannot_detect_encoding(mocker):
  148. mocker.patch("pathlib.Path.exists", return_value=True)
  149. mock_open = mocker.patch("builtins.open")
  150. mocker.patch(
  151. "charset_normalizer.from_path",
  152. return_value=Mock(best=lambda: None),
  153. )
  154. initialize_requirements_txt()
  155. mock_open.assert_not_called()
  156. def test_requirements_txt_other_encoding(mocker):
  157. mocker.patch("pathlib.Path.exists", return_value=True)
  158. mocker.patch(
  159. "charset_normalizer.from_path",
  160. return_value=Mock(best=lambda: Mock(encoding="utf-16")),
  161. )
  162. initialize_requirements_txt()
  163. open_mock = mock_open(read_data="random-package=1.2.3")
  164. mocker.patch("pathlib.Path.open", open_mock)
  165. initialize_requirements_txt()
  166. # Currently open for read, then open for append
  167. assert open_mock.call_count == 2
  168. for call_args in open_mock.call_args_list:
  169. assert call_args.kwargs["encoding"] == "utf-16"
  170. assert (
  171. open_mock().write.call_args[0][0]
  172. == f"\n{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
  173. )
  174. def test_cached_procedure():
  175. call_count = 0
  176. @cached_procedure(tempfile.mktemp(), payload_fn=lambda: "constant")
  177. def _function_with_no_args():
  178. nonlocal call_count
  179. call_count += 1
  180. _function_with_no_args()
  181. assert call_count == 1
  182. _function_with_no_args()
  183. assert call_count == 1
  184. call_count = 0
  185. @cached_procedure(
  186. tempfile.mktemp(),
  187. payload_fn=lambda *args, **kwargs: f"{repr(args), repr(kwargs)}",
  188. )
  189. def _function_with_some_args(*args, **kwargs):
  190. nonlocal call_count
  191. call_count += 1
  192. _function_with_some_args(1, y=2)
  193. assert call_count == 1
  194. _function_with_some_args(1, y=2)
  195. assert call_count == 1
  196. _function_with_some_args(100, y=300)
  197. assert call_count == 2
  198. _function_with_some_args(100, y=300)
  199. assert call_count == 2
  200. def test_get_cpu_info():
  201. cpu_info = get_cpu_info()
  202. assert cpu_info is not None
  203. assert isinstance(cpu_info, CpuInfo)
  204. assert cpu_info.model_name is not None
  205. for attr in ("manufacturer_id", "model_name", "address_width"):
  206. value = getattr(cpu_info, attr)
  207. assert value.strip() if attr != "address_width" else value
  208. @pytest.fixture
  209. def temp_directory():
  210. temp_dir = tempfile.mkdtemp()
  211. yield Path(temp_dir)
  212. shutil.rmtree(temp_dir)
  213. @pytest.mark.parametrize(
  214. "config_code,expected",
  215. [
  216. ("rx.Config(app_name='old_name')", 'rx.Config(app_name="new_name")'),
  217. ('rx.Config(app_name="old_name")', 'rx.Config(app_name="new_name")'),
  218. ("rx.Config('old_name')", 'rx.Config("new_name")'),
  219. ('rx.Config("old_name")', 'rx.Config("new_name")'),
  220. ],
  221. )
  222. def test_rename_imports_and_app_name(temp_directory, config_code, expected):
  223. file_path = temp_directory / "rxconfig.py"
  224. content = f"""
  225. config = {config_code}
  226. """
  227. file_path.write_text(content)
  228. rename_imports_and_app_name(file_path, "old_name", "new_name")
  229. updated_content = file_path.read_text()
  230. expected_content = f"""
  231. config = {expected}
  232. """
  233. assert updated_content == expected_content
  234. def test_regex_edge_cases(temp_directory):
  235. file_path = temp_directory / "example.py"
  236. content = """
  237. from old_name.module import something
  238. import old_name
  239. from old_name import something_else as alias
  240. from old_name
  241. """
  242. file_path.write_text(content)
  243. rename_imports_and_app_name(file_path, "old_name", "new_name")
  244. updated_content = file_path.read_text()
  245. expected_content = """
  246. from new_name.module import something
  247. import new_name
  248. from new_name import something_else as alias
  249. from new_name
  250. """
  251. assert updated_content == expected_content
  252. def test_cli_rename_command(temp_directory):
  253. foo_dir = temp_directory / "foo"
  254. foo_dir.mkdir()
  255. (foo_dir / "__init__").touch()
  256. (foo_dir / ".web").mkdir()
  257. (foo_dir / "assets").mkdir()
  258. (foo_dir / "foo").mkdir()
  259. (foo_dir / "foo" / "__init__.py").touch()
  260. (foo_dir / "rxconfig.py").touch()
  261. (foo_dir / "rxconfig.py").write_text(
  262. """
  263. import reflex as rx
  264. config = rx.Config(
  265. app_name="foo",
  266. )
  267. """
  268. )
  269. (foo_dir / "foo" / "components").mkdir()
  270. (foo_dir / "foo" / "components" / "__init__.py").touch()
  271. (foo_dir / "foo" / "components" / "base.py").touch()
  272. (foo_dir / "foo" / "components" / "views.py").touch()
  273. (foo_dir / "foo" / "components" / "base.py").write_text(
  274. """
  275. import reflex as rx
  276. from foo.components import views
  277. from foo.components.views import *
  278. from .base import *
  279. def random_component():
  280. return rx.fragment()
  281. """
  282. )
  283. (foo_dir / "foo" / "foo.py").touch()
  284. (foo_dir / "foo" / "foo.py").write_text(
  285. """
  286. import reflex as rx
  287. import foo.components.base
  288. from foo.components.base import random_component
  289. class State(rx.State):
  290. pass
  291. def index():
  292. return rx.text("Hello, World!")
  293. app = rx.App()
  294. app.add_page(index)
  295. """
  296. )
  297. with chdir(temp_directory / "foo"):
  298. result = runner.invoke(cli, ["rename", "bar"])
  299. assert result.exit_code == 0
  300. assert (foo_dir / "rxconfig.py").read_text() == (
  301. """
  302. import reflex as rx
  303. config = rx.Config(
  304. app_name="bar",
  305. )
  306. """
  307. )
  308. assert (foo_dir / "bar").exists()
  309. assert not (foo_dir / "foo").exists()
  310. assert (foo_dir / "bar" / "components" / "base.py").read_text() == (
  311. """
  312. import reflex as rx
  313. from bar.components import views
  314. from bar.components.views import *
  315. from .base import *
  316. def random_component():
  317. return rx.fragment()
  318. """
  319. )
  320. assert (foo_dir / "bar" / "bar.py").exists()
  321. assert not (foo_dir / "bar" / "foo.py").exists()
  322. assert (foo_dir / "bar" / "bar.py").read_text() == (
  323. """
  324. import reflex as rx
  325. import bar.components.base
  326. from bar.components.base import random_component
  327. class State(rx.State):
  328. pass
  329. def index():
  330. return rx.text("Hello, World!")
  331. app = rx.App()
  332. app.add_page(index)
  333. """
  334. )