testing.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. """reflex.testing - tools for testing reflex apps."""
  2. from __future__ import annotations
  3. import contextlib
  4. import dataclasses
  5. import inspect
  6. import os
  7. import pathlib
  8. import platform
  9. import re
  10. import signal
  11. import socket
  12. import subprocess
  13. import textwrap
  14. import threading
  15. import time
  16. import types
  17. from typing import (
  18. TYPE_CHECKING,
  19. Any,
  20. Callable,
  21. Coroutine,
  22. Optional,
  23. Type,
  24. TypeVar,
  25. Union,
  26. cast,
  27. )
  28. import psutil
  29. import uvicorn
  30. import reflex
  31. import reflex.reflex
  32. import reflex.utils.build
  33. import reflex.utils.exec
  34. import reflex.utils.prerequisites
  35. import reflex.utils.processes
  36. from reflex.app import EventNamespace
  37. try:
  38. from selenium import webdriver # pyright: ignore [reportMissingImports]
  39. from selenium.webdriver.remote.webdriver import ( # pyright: ignore [reportMissingImports]
  40. WebDriver,
  41. )
  42. if TYPE_CHECKING:
  43. from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports]
  44. WebElement,
  45. )
  46. has_selenium = True
  47. except ImportError:
  48. has_selenium = False
  49. DEFAULT_TIMEOUT = 10
  50. POLL_INTERVAL = 0.25
  51. FRONTEND_LISTENING_MESSAGE = re.compile(r"ready started server on.*, url: (.*:[0-9]+)$")
  52. FRONTEND_POPEN_ARGS = {}
  53. T = TypeVar("T")
  54. TimeoutType = Optional[Union[int, float]]
  55. if platform.system == "Windows":
  56. FRONTEND_POPEN_ARGS["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore
  57. else:
  58. FRONTEND_POPEN_ARGS["start_new_session"] = True
  59. # borrowed from py3.11
  60. class chdir(contextlib.AbstractContextManager):
  61. """Non thread-safe context manager to change the current working directory."""
  62. def __init__(self, path):
  63. """Prepare contextmanager.
  64. Args:
  65. path: the path to change to
  66. """
  67. self.path = path
  68. self._old_cwd = []
  69. def __enter__(self):
  70. """Save current directory and perform chdir."""
  71. self._old_cwd.append(os.getcwd())
  72. os.chdir(self.path)
  73. def __exit__(self, *excinfo):
  74. """Change back to previous directory on stack.
  75. Args:
  76. excinfo: sys.exc_info captured in the context block
  77. """
  78. os.chdir(self._old_cwd.pop())
  79. @dataclasses.dataclass
  80. class AppHarness:
  81. """AppHarness executes a reflex app in-process for testing."""
  82. app_name: str
  83. app_source: Optional[types.FunctionType | types.ModuleType]
  84. app_path: pathlib.Path
  85. app_module_path: pathlib.Path
  86. app_module: Optional[types.ModuleType] = None
  87. app_instance: Optional[reflex.App] = None
  88. frontend_process: Optional[subprocess.Popen] = None
  89. frontend_url: Optional[str] = None
  90. backend_thread: Optional[threading.Thread] = None
  91. backend: Optional[uvicorn.Server] = None
  92. _frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
  93. @classmethod
  94. def create(
  95. cls,
  96. root: pathlib.Path,
  97. app_source: Optional[types.FunctionType | types.ModuleType] = None,
  98. app_name: Optional[str] = None,
  99. ) -> "AppHarness":
  100. """Create an AppHarness instance at root.
  101. Args:
  102. root: the directory that will contain the app under test.
  103. app_source: if specified, the source code from this function or module is used
  104. as the main module for the app. If unspecified, then root must already
  105. contain a working reflex app and will be used directly.
  106. app_name: provide the name of the app, otherwise will be derived from app_source or root.
  107. Returns:
  108. AppHarness instance
  109. """
  110. if app_name is None:
  111. if app_source is None:
  112. app_name = root.name.lower()
  113. else:
  114. app_name = app_source.__name__.lower()
  115. return cls(
  116. app_name=app_name,
  117. app_source=app_source,
  118. app_path=root,
  119. app_module_path=root / app_name / f"{app_name}.py",
  120. )
  121. def _initialize_app(self):
  122. os.environ["TELEMETRY_ENABLED"] = "" # disable telemetry reporting for tests
  123. self.app_path.mkdir(parents=True, exist_ok=True)
  124. if self.app_source is not None:
  125. # get the source from a function or module object
  126. source_code = textwrap.dedent(
  127. "".join(inspect.getsource(self.app_source).splitlines(True)[1:]),
  128. )
  129. with chdir(self.app_path):
  130. reflex.reflex.init(
  131. name=self.app_name,
  132. template=reflex.constants.Template.DEFAULT,
  133. loglevel=reflex.constants.LogLevel.INFO,
  134. )
  135. self.app_module_path.write_text(source_code)
  136. with chdir(self.app_path):
  137. # ensure config is reloaded when testing different app
  138. reflex.config.get_config(reload=True)
  139. self.app_module = reflex.utils.prerequisites.get_app()
  140. self.app_instance = self.app_module.app
  141. def _start_backend(self):
  142. if self.app_instance is None:
  143. raise RuntimeError("App was not initialized.")
  144. self.backend = uvicorn.Server(
  145. uvicorn.Config(
  146. app=self.app_instance.api,
  147. host="127.0.0.1",
  148. port=0,
  149. )
  150. )
  151. self.backend_thread = threading.Thread(target=self.backend.run)
  152. self.backend_thread.start()
  153. def _start_frontend(self):
  154. # Set up the frontend.
  155. with chdir(self.app_path):
  156. config = reflex.config.get_config()
  157. config.api_url = "http://{0}:{1}".format(
  158. *self._poll_for_servers().getsockname(),
  159. )
  160. reflex.utils.build.setup_frontend(self.app_path)
  161. # Start the frontend.
  162. self.frontend_process = reflex.utils.processes.new_process(
  163. [reflex.utils.prerequisites.get_package_manager(), "run", "dev"],
  164. cwd=self.app_path / reflex.constants.WEB_DIR,
  165. env={"PORT": "0"},
  166. **FRONTEND_POPEN_ARGS,
  167. )
  168. def _wait_frontend(self):
  169. while self.frontend_url is None:
  170. line = (
  171. self.frontend_process.stdout.readline() # pyright: ignore [reportOptionalMemberAccess]
  172. )
  173. if not line:
  174. break
  175. print(line) # for pytest diagnosis
  176. m = FRONTEND_LISTENING_MESSAGE.search(line)
  177. if m is not None:
  178. self.frontend_url = m.group(1)
  179. break
  180. if self.frontend_url is None:
  181. raise RuntimeError("Frontend did not start")
  182. def start(self) -> "AppHarness":
  183. """Start the backend in a new thread and dev frontend as a separate process.
  184. Returns:
  185. self
  186. """
  187. self._initialize_app()
  188. self._start_backend()
  189. self._start_frontend()
  190. self._wait_frontend()
  191. return self
  192. def __enter__(self) -> "AppHarness":
  193. """Contextmanager protocol for `start()`.
  194. Returns:
  195. Instance of AppHarness after calling start()
  196. """
  197. return self.start()
  198. def stop(self) -> None:
  199. """Stop the frontend and backend servers."""
  200. if self.backend is not None:
  201. self.backend.should_exit = True
  202. if self.frontend_process is not None:
  203. # https://stackoverflow.com/a/70565806
  204. frontend_children = psutil.Process(self.frontend_process.pid).children(
  205. recursive=True,
  206. )
  207. if platform.system() == "Windows":
  208. self.frontend_process.terminate()
  209. else:
  210. pgrp = os.getpgid(self.frontend_process.pid)
  211. os.killpg(pgrp, signal.SIGTERM)
  212. # kill any remaining child processes
  213. for child in frontend_children:
  214. # It's okay if the process is already gone.
  215. with contextlib.suppress(psutil.NoSuchProcess):
  216. child.terminate()
  217. _, still_alive = psutil.wait_procs(frontend_children, timeout=3)
  218. for child in still_alive:
  219. # It's okay if the process is already gone.
  220. with contextlib.suppress(psutil.NoSuchProcess):
  221. child.kill()
  222. # wait for main process to exit
  223. self.frontend_process.communicate()
  224. if self.backend_thread is not None:
  225. self.backend_thread.join()
  226. for driver in self._frontends:
  227. driver.quit()
  228. def __exit__(self, *excinfo) -> None:
  229. """Contextmanager protocol for `stop()`.
  230. Args:
  231. excinfo: sys.exc_info captured in the context block
  232. """
  233. self.stop()
  234. @staticmethod
  235. def _poll_for(
  236. target: Callable[[], T],
  237. timeout: TimeoutType = None,
  238. step: TimeoutType = None,
  239. ) -> T | bool:
  240. """Generic polling logic.
  241. Args:
  242. target: callable that returns truthy if polling condition is met.
  243. timeout: max polling time
  244. step: interval between checking target()
  245. Returns:
  246. return value of target() if truthy within timeout
  247. False if timeout elapses
  248. """
  249. if timeout is None:
  250. timeout = DEFAULT_TIMEOUT
  251. if step is None:
  252. step = POLL_INTERVAL
  253. deadline = time.time() + timeout
  254. while time.time() < deadline:
  255. success = target()
  256. if success:
  257. return success
  258. time.sleep(step)
  259. return False
  260. def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
  261. """Poll backend server for listening sockets.
  262. Args:
  263. timeout: how long to wait for listening socket.
  264. Returns:
  265. first active listening socket on the backend
  266. Raises:
  267. RuntimeError: when the backend hasn't started running
  268. TimeoutError: when server or sockets are not ready
  269. """
  270. if self.backend is None:
  271. raise RuntimeError("Backend is not running.")
  272. backend = self.backend
  273. # check for servers to be initialized
  274. if not self._poll_for(
  275. target=lambda: getattr(backend, "servers", False),
  276. timeout=timeout,
  277. ):
  278. raise TimeoutError("Backend servers are not initialized.")
  279. # check for sockets to be listening
  280. if not self._poll_for(
  281. target=lambda: getattr(backend.servers[0], "sockets", False),
  282. timeout=timeout,
  283. ):
  284. raise TimeoutError("Backend is not listening.")
  285. return backend.servers[0].sockets[0]
  286. def frontend(self, driver_clz: Optional[Type["WebDriver"]] = None) -> "WebDriver":
  287. """Get a selenium webdriver instance pointed at the app.
  288. Args:
  289. driver_clz: webdriver.Chrome (default), webdriver.Firefox, webdriver.Safari,
  290. webdriver.Edge, etc
  291. Returns:
  292. Instance of the given webdriver navigated to the frontend url of the app.
  293. Raises:
  294. RuntimeError: when selenium is not importable or frontend is not running
  295. """
  296. if not has_selenium:
  297. raise RuntimeError(
  298. "Frontend functionality requires `selenium` to be installed, "
  299. "and it could not be imported."
  300. )
  301. if self.frontend_url is None:
  302. raise RuntimeError("Frontend is not running.")
  303. driver = driver_clz() if driver_clz is not None else webdriver.Chrome()
  304. driver.get(self.frontend_url)
  305. self._frontends.append(driver)
  306. return driver
  307. async def emit_state_updates(self) -> list[Any]:
  308. """Send any backend state deltas to the frontend.
  309. Returns:
  310. List of awaited response from each EventNamespace.emit() call.
  311. Raises:
  312. RuntimeError: when the app hasn't started running
  313. """
  314. if self.app_instance is None or self.app_instance.sio is None:
  315. raise RuntimeError("App is not running.")
  316. event_ns: EventNamespace = cast(
  317. EventNamespace,
  318. self.app_instance.event_namespace,
  319. )
  320. pending: list[Coroutine[Any, Any, Any]] = []
  321. for state in self.app_instance.state_manager.states.values():
  322. delta = state.get_delta()
  323. if delta:
  324. update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
  325. state._clean()
  326. # Emit the event.
  327. pending.append(
  328. event_ns.emit(
  329. str(reflex.constants.SocketEvent.EVENT),
  330. update.json(),
  331. to=state.get_sid(),
  332. ),
  333. )
  334. responses = []
  335. for request in pending:
  336. responses.append(await request)
  337. return responses
  338. def poll_for_content(
  339. self,
  340. element: "WebElement",
  341. timeout: TimeoutType = None,
  342. exp_not_equal: str = "",
  343. ) -> str:
  344. """Poll element.text for change.
  345. Args:
  346. element: selenium webdriver element to check
  347. timeout: how long to poll element.text
  348. exp_not_equal: exit the polling loop when the element text does not match
  349. Returns:
  350. The element text when the polling loop exited
  351. Raises:
  352. TimeoutError: when the timeout expires before text changes
  353. """
  354. if not self._poll_for(
  355. target=lambda: element.text != exp_not_equal,
  356. timeout=timeout,
  357. ):
  358. raise TimeoutError(
  359. f"{element} content remains {exp_not_equal!r} while polling.",
  360. )
  361. return element.text
  362. def poll_for_value(
  363. self,
  364. element: "WebElement",
  365. timeout: TimeoutType = None,
  366. exp_not_equal: str = "",
  367. ) -> Optional[str]:
  368. """Poll element.get_attribute("value") for change.
  369. Args:
  370. element: selenium webdriver element to check
  371. timeout: how long to poll element value attribute
  372. exp_not_equal: exit the polling loop when the value does not match
  373. Returns:
  374. The element value when the polling loop exited
  375. Raises:
  376. TimeoutError: when the timeout expires before value changes
  377. """
  378. if not self._poll_for(
  379. target=lambda: element.get_attribute("value") != exp_not_equal,
  380. timeout=timeout,
  381. ):
  382. raise TimeoutError(
  383. f"{element} content remains {exp_not_equal!r} while polling.",
  384. )
  385. return element.get_attribute("value")
  386. def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]:
  387. """Poll app state_manager for any connected clients.
  388. Args:
  389. timeout: how long to wait for client states
  390. Returns:
  391. active state instances when the polling loop exited
  392. Raises:
  393. RuntimeError: when the app hasn't started running
  394. TimeoutError: when the timeout expires before any states are seen
  395. """
  396. if self.app_instance is None:
  397. raise RuntimeError("App is not running.")
  398. state_manager = self.app_instance.state_manager
  399. if not self._poll_for(
  400. target=lambda: state_manager.states,
  401. timeout=timeout,
  402. ):
  403. raise TimeoutError("No states were observed while polling.")
  404. return state_manager.states