# test_prerequisites.py
  1. import json
  2. import re
  3. import shutil
  4. import tempfile
  5. from pathlib import Path
  6. import pytest
  7. from typer.testing import CliRunner
  8. from reflex.config import Config
  9. from reflex.reflex import cli
  10. from reflex.testing import chdir
  11. from reflex.utils.prerequisites import (
  12. CpuInfo,
  13. _update_next_config,
  14. cached_procedure,
  15. get_cpu_info,
  16. rename_imports_and_app_name,
  17. )
# Typer ``CliRunner`` shared by tests that drive the ``reflex`` CLI through
# ``runner.invoke(cli, [...])``.
runner = CliRunner()
  19. @pytest.mark.parametrize(
  20. "config, export, expected_output",
  21. [
  22. (
  23. Config(
  24. app_name="test",
  25. ),
  26. False,
  27. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  28. ),
  29. (
  30. Config(
  31. app_name="test",
  32. static_page_generation_timeout=30,
  33. ),
  34. False,
  35. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 30, devIndicators: false};',
  36. ),
  37. (
  38. Config(
  39. app_name="test",
  40. next_compression=False,
  41. ),
  42. False,
  43. 'module.exports = {basePath: "", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  44. ),
  45. (
  46. Config(
  47. app_name="test",
  48. frontend_path="/test",
  49. ),
  50. False,
  51. 'module.exports = {basePath: "/test", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  52. ),
  53. (
  54. Config(
  55. app_name="test",
  56. frontend_path="/test",
  57. next_compression=False,
  58. ),
  59. False,
  60. 'module.exports = {basePath: "/test", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false};',
  61. ),
  62. (
  63. Config(
  64. app_name="test",
  65. ),
  66. True,
  67. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: false, output: "export", distDir: "_static"};',
  68. ),
  69. (
  70. Config(
  71. app_name="test",
  72. next_dev_indicators=True,
  73. ),
  74. True,
  75. 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, devIndicators: true, output: "export", distDir: "_static"};',
  76. ),
  77. ],
  78. )
  79. def test_update_next_config(config, export, expected_output):
  80. output = _update_next_config(config, export=export)
  81. assert output == expected_output
  82. @pytest.mark.parametrize(
  83. ("transpile_packages", "expected_transpile_packages"),
  84. (
  85. (
  86. ["foo", "@bar/baz"],
  87. ["@bar/baz", "foo"],
  88. ),
  89. (
  90. ["foo", "@bar/baz", "foo", "@bar/baz@3.2.1"],
  91. ["@bar/baz", "foo"],
  92. ),
  93. ),
  94. )
  95. def test_transpile_packages(transpile_packages, expected_transpile_packages):
  96. output = _update_next_config(
  97. Config(app_name="test"),
  98. transpile_packages=transpile_packages,
  99. )
  100. transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", output)
  101. transpile_packages_json = transpile_packages_match.group(1) # pyright: ignore [reportOptionalMemberAccess]
  102. actual_transpile_packages = sorted(json.loads(transpile_packages_json))
  103. assert actual_transpile_packages == expected_transpile_packages
  104. def test_cached_procedure():
  105. call_count = 0
  106. @cached_procedure(tempfile.mktemp(), payload_fn=lambda: "constant")
  107. def _function_with_no_args():
  108. nonlocal call_count
  109. call_count += 1
  110. _function_with_no_args()
  111. assert call_count == 1
  112. _function_with_no_args()
  113. assert call_count == 1
  114. call_count = 0
  115. @cached_procedure(
  116. cache_file=tempfile.mktemp(),
  117. payload_fn=lambda *args, **kwargs: f"{repr(args), repr(kwargs)}",
  118. )
  119. def _function_with_some_args(*args, **kwargs):
  120. nonlocal call_count
  121. call_count += 1
  122. _function_with_some_args(1, y=2)
  123. assert call_count == 1
  124. _function_with_some_args(1, y=2)
  125. assert call_count == 1
  126. _function_with_some_args(100, y=300)
  127. assert call_count == 2
  128. _function_with_some_args(100, y=300)
  129. assert call_count == 2
  130. call_count = 0
  131. @cached_procedure(
  132. cache_file=None, cache_file_fn=tempfile.mktemp, payload_fn=lambda: "constant"
  133. )
  134. def _function_with_no_args_fn():
  135. nonlocal call_count
  136. call_count += 1
  137. _function_with_no_args_fn()
  138. assert call_count == 1
  139. _function_with_no_args_fn()
  140. assert call_count == 2
  141. def test_get_cpu_info():
  142. cpu_info = get_cpu_info()
  143. assert cpu_info is not None
  144. assert isinstance(cpu_info, CpuInfo)
  145. assert cpu_info.model_name is not None
  146. for attr in ("manufacturer_id", "model_name", "address_width"):
  147. value = getattr(cpu_info, attr)
  148. assert value.strip() if attr != "address_width" else value
  149. @pytest.fixture
  150. def temp_directory():
  151. temp_dir = tempfile.mkdtemp()
  152. yield Path(temp_dir)
  153. shutil.rmtree(temp_dir)
  154. @pytest.mark.parametrize(
  155. "config_code,expected",
  156. [
  157. ("rx.Config(app_name='old_name')", 'rx.Config(app_name="new_name")'),
  158. ('rx.Config(app_name="old_name")', 'rx.Config(app_name="new_name")'),
  159. ("rx.Config('old_name')", 'rx.Config("new_name")'),
  160. ('rx.Config("old_name")', 'rx.Config("new_name")'),
  161. ],
  162. )
  163. def test_rename_imports_and_app_name(temp_directory, config_code, expected):
  164. file_path = temp_directory / "rxconfig.py"
  165. content = f"""
  166. config = {config_code}
  167. """
  168. file_path.write_text(content)
  169. rename_imports_and_app_name(file_path, "old_name", "new_name")
  170. updated_content = file_path.read_text()
  171. expected_content = f"""
  172. config = {expected}
  173. """
  174. assert updated_content == expected_content
  175. def test_regex_edge_cases(temp_directory):
  176. file_path = temp_directory / "example.py"
  177. content = """
  178. from old_name.module import something
  179. import old_name
  180. from old_name import something_else as alias
  181. from old_name
  182. """
  183. file_path.write_text(content)
  184. rename_imports_and_app_name(file_path, "old_name", "new_name")
  185. updated_content = file_path.read_text()
  186. expected_content = """
  187. from new_name.module import something
  188. import new_name
  189. from new_name import something_else as alias
  190. from new_name
  191. """
  192. assert updated_content == expected_content
  193. def test_cli_rename_command(temp_directory):
  194. foo_dir = temp_directory / "foo"
  195. foo_dir.mkdir()
  196. (foo_dir / "__init__").touch()
  197. (foo_dir / ".web").mkdir()
  198. (foo_dir / "assets").mkdir()
  199. (foo_dir / "foo").mkdir()
  200. (foo_dir / "foo" / "__init__.py").touch()
  201. (foo_dir / "rxconfig.py").touch()
  202. (foo_dir / "rxconfig.py").write_text(
  203. """
  204. import reflex as rx
  205. config = rx.Config(
  206. app_name="foo",
  207. )
  208. """
  209. )
  210. (foo_dir / "foo" / "components").mkdir()
  211. (foo_dir / "foo" / "components" / "__init__.py").touch()
  212. (foo_dir / "foo" / "components" / "base.py").touch()
  213. (foo_dir / "foo" / "components" / "views.py").touch()
  214. (foo_dir / "foo" / "components" / "base.py").write_text(
  215. """
  216. import reflex as rx
  217. from foo.components import views
  218. from foo.components.views import *
  219. from .base import *
  220. def random_component():
  221. return rx.fragment()
  222. """
  223. )
  224. (foo_dir / "foo" / "foo.py").touch()
  225. (foo_dir / "foo" / "foo.py").write_text(
  226. """
  227. import reflex as rx
  228. import foo.components.base
  229. from foo.components.base import random_component
  230. class State(rx.State):
  231. pass
  232. def index():
  233. return rx.text("Hello, World!")
  234. app = rx.App()
  235. app.add_page(index)
  236. """
  237. )
  238. with chdir(temp_directory / "foo"):
  239. result = runner.invoke(cli, ["rename", "bar"])
  240. assert result.exit_code == 0
  241. assert (foo_dir / "rxconfig.py").read_text() == (
  242. """
  243. import reflex as rx
  244. config = rx.Config(
  245. app_name="bar",
  246. )
  247. """
  248. )
  249. assert (foo_dir / "bar").exists()
  250. assert not (foo_dir / "foo").exists()
  251. assert (foo_dir / "bar" / "components" / "base.py").read_text() == (
  252. """
  253. import reflex as rx
  254. from bar.components import views
  255. from bar.components.views import *
  256. from .base import *
  257. def random_component():
  258. return rx.fragment()
  259. """
  260. )
  261. assert (foo_dir / "bar" / "bar.py").exists()
  262. assert not (foo_dir / "bar" / "foo.py").exists()
  263. assert (foo_dir / "bar" / "bar.py").read_text() == (
  264. """
  265. import reflex as rx
  266. import bar.components.base
  267. from bar.components.base import random_component
  268. class State(rx.State):
  269. pass
  270. def index():
  271. return rx.text("Hello, World!")
  272. app = rx.App()
  273. app.add_page(index)
  274. """
  275. )