test_prerequisites.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. import json
  2. import re
  3. import shutil
  4. import tempfile
  5. from pathlib import Path
  6. import pytest
  7. from click.testing import CliRunner
  8. from reflex.config import Config
  9. from reflex.reflex import cli
  10. from reflex.testing import chdir
  11. from reflex.utils.prerequisites import (
  12. CpuInfo,
  13. _update_next_config,
  14. cached_procedure,
  15. get_cpu_info,
  16. rename_imports_and_app_name,
  17. )
  18. runner = CliRunner()
  19. @pytest.mark.parametrize(
  20. "config, export, expected_output",
  21. [
  22. (
  23. Config(
  24. app_name="test",
  25. ),
  26. False,
  27. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  28. ),
  29. (
  30. Config(
  31. app_name="test",
  32. static_page_generation_timeout=30,
  33. ),
  34. False,
  35. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 30, devIndicators: false};',
  36. ),
  37. (
  38. Config(
  39. app_name="test",
  40. next_compression=False,
  41. ),
  42. False,
  43. 'module.exports = {basePath: "", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  44. ),
  45. (
  46. Config(
  47. app_name="test",
  48. frontend_path="/test",
  49. ),
  50. False,
  51. 'module.exports = {basePath: "/test", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  52. ),
  53. (
  54. Config(
  55. app_name="test",
  56. frontend_path="/test",
  57. next_compression=False,
  58. ),
  59. False,
  60. 'module.exports = {basePath: "/test", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  61. ),
  62. (
  63. Config(
  64. app_name="test",
  65. ),
  66. True,
  67. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false, output: "export", distDir: "_static"};',
  68. ),
  69. (
  70. Config(
  71. app_name="test",
  72. next_dev_indicators=True,
  73. ),
  74. True,
  75. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, output: "export", distDir: "_static"};',
  76. ),
  77. ],
  78. )
  79. def test_update_next_config(config, export, expected_output):
  80. output = _update_next_config(config, export=export)
  81. assert output == expected_output
  82. @pytest.mark.parametrize(
  83. ("transpile_packages", "expected_transpile_packages"),
  84. (
  85. (
  86. ["foo", "@bar/baz"],
  87. ["@bar/baz", "foo"],
  88. ),
  89. (
  90. ["foo", "@bar/baz", "foo", "@bar/baz@3.2.1"],
  91. ["@bar/baz", "foo"],
  92. ),
  93. (["@bar/baz", {"name": "foo"}], ["@bar/baz", "foo"]),
  94. (["@bar/baz", {"name": "@foo/baz"}], ["@bar/baz", "@foo/baz"]),
  95. ),
  96. )
  97. def test_transpile_packages(transpile_packages, expected_transpile_packages):
  98. output = _update_next_config(
  99. Config(app_name="test"),
  100. transpile_packages=transpile_packages,
  101. )
  102. transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", output)
  103. transpile_packages_json = transpile_packages_match.group(1) # pyright: ignore [reportOptionalMemberAccess]
  104. actual_transpile_packages = sorted(json.loads(transpile_packages_json))
  105. assert actual_transpile_packages == expected_transpile_packages
  106. def test_cached_procedure():
  107. call_count = 0
  108. @cached_procedure(tempfile.mktemp(), payload_fn=lambda: "constant")
  109. def _function_with_no_args():
  110. nonlocal call_count
  111. call_count += 1
  112. _function_with_no_args()
  113. assert call_count == 1
  114. _function_with_no_args()
  115. assert call_count == 1
  116. call_count = 0
  117. @cached_procedure(
  118. cache_file=tempfile.mktemp(),
  119. payload_fn=lambda *args, **kwargs: f"{repr(args), repr(kwargs)}",
  120. )
  121. def _function_with_some_args(*args, **kwargs):
  122. nonlocal call_count
  123. call_count += 1
  124. _function_with_some_args(1, y=2)
  125. assert call_count == 1
  126. _function_with_some_args(1, y=2)
  127. assert call_count == 1
  128. _function_with_some_args(100, y=300)
  129. assert call_count == 2
  130. _function_with_some_args(100, y=300)
  131. assert call_count == 2
  132. call_count = 0
  133. @cached_procedure(
  134. cache_file=None, cache_file_fn=tempfile.mktemp, payload_fn=lambda: "constant"
  135. )
  136. def _function_with_no_args_fn():
  137. nonlocal call_count
  138. call_count += 1
  139. _function_with_no_args_fn()
  140. assert call_count == 1
  141. _function_with_no_args_fn()
  142. assert call_count == 2
  143. def test_get_cpu_info():
  144. cpu_info = get_cpu_info()
  145. assert cpu_info is not None
  146. assert isinstance(cpu_info, CpuInfo)
  147. assert cpu_info.model_name is not None
  148. for attr in ("manufacturer_id", "model_name", "address_width"):
  149. value = getattr(cpu_info, attr)
  150. assert value.strip() if attr != "address_width" else value
  151. @pytest.fixture
  152. def temp_directory():
  153. temp_dir = tempfile.mkdtemp()
  154. yield Path(temp_dir)
  155. shutil.rmtree(temp_dir)
  156. @pytest.mark.parametrize(
  157. "config_code,expected",
  158. [
  159. ("rx.Config(app_name='old_name')", 'rx.Config(app_name="new_name")'),
  160. ('rx.Config(app_name="old_name")', 'rx.Config(app_name="new_name")'),
  161. ("rx.Config('old_name')", 'rx.Config("new_name")'),
  162. ('rx.Config("old_name")', 'rx.Config("new_name")'),
  163. ],
  164. )
  165. def test_rename_imports_and_app_name(temp_directory, config_code, expected):
  166. file_path = temp_directory / "rxconfig.py"
  167. content = f"""
  168. config = {config_code}
  169. """
  170. file_path.write_text(content)
  171. rename_imports_and_app_name(file_path, "old_name", "new_name")
  172. updated_content = file_path.read_text()
  173. expected_content = f"""
  174. config = {expected}
  175. """
  176. assert updated_content == expected_content
  177. def test_regex_edge_cases(temp_directory):
  178. file_path = temp_directory / "example.py"
  179. content = """
  180. from old_name.module import something
  181. import old_name
  182. from old_name import something_else as alias
  183. from old_name
  184. """
  185. file_path.write_text(content)
  186. rename_imports_and_app_name(file_path, "old_name", "new_name")
  187. updated_content = file_path.read_text()
  188. expected_content = """
  189. from new_name.module import something
  190. import new_name
  191. from new_name import something_else as alias
  192. from new_name
  193. """
  194. assert updated_content == expected_content
  195. def test_cli_rename_command(temp_directory):
  196. foo_dir = temp_directory / "foo"
  197. foo_dir.mkdir()
  198. (foo_dir / "__init__").touch()
  199. (foo_dir / ".web").mkdir()
  200. (foo_dir / "assets").mkdir()
  201. (foo_dir / "foo").mkdir()
  202. (foo_dir / "foo" / "__init__.py").touch()
  203. (foo_dir / "rxconfig.py").touch()
  204. (foo_dir / "rxconfig.py").write_text(
  205. """
  206. import reflex as rx
  207. config = rx.Config(
  208. app_name="foo",
  209. )
  210. """
  211. )
  212. (foo_dir / "foo" / "components").mkdir()
  213. (foo_dir / "foo" / "components" / "__init__.py").touch()
  214. (foo_dir / "foo" / "components" / "base.py").touch()
  215. (foo_dir / "foo" / "components" / "views.py").touch()
  216. (foo_dir / "foo" / "components" / "base.py").write_text(
  217. """
  218. import reflex as rx
  219. from foo.components import views
  220. from foo.components.views import *
  221. from .base import *
  222. def random_component():
  223. return rx.fragment()
  224. """
  225. )
  226. (foo_dir / "foo" / "foo.py").touch()
  227. (foo_dir / "foo" / "foo.py").write_text(
  228. """
  229. import reflex as rx
  230. import foo.components.base
  231. from foo.components.base import random_component
  232. class State(rx.State):
  233. pass
  234. def index():
  235. return rx.text("Hello, World!")
  236. app = rx.App()
  237. app.add_page(index)
  238. """
  239. )
  240. with chdir(temp_directory / "foo"):
  241. result = runner.invoke(cli, ["rename", "bar"])
  242. assert result.exit_code == 0, result.output
  243. assert (foo_dir / "rxconfig.py").read_text() == (
  244. """
  245. import reflex as rx
  246. config = rx.Config(
  247. app_name="bar",
  248. )
  249. """
  250. )
  251. assert (foo_dir / "bar").exists()
  252. assert not (foo_dir / "foo").exists()
  253. assert (foo_dir / "bar" / "components" / "base.py").read_text() == (
  254. """
  255. import reflex as rx
  256. from bar.components import views
  257. from bar.components.views import *
  258. from .base import *
  259. def random_component():
  260. return rx.fragment()
  261. """
  262. )
  263. assert (foo_dir / "bar" / "bar.py").exists()
  264. assert not (foo_dir / "bar" / "foo.py").exists()
  265. assert (foo_dir / "bar" / "bar.py").read_text() == (
  266. """
  267. import reflex as rx
  268. import bar.components.base
  269. from bar.components.base import random_component
  270. class State(rx.State):
  271. pass
  272. def index():
  273. return rx.text("Hello, World!")
  274. app = rx.App()
  275. app.add_page(index)
  276. """
  277. )