testing.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014
  1. """reflex.testing - tools for testing reflex apps."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import dataclasses
  6. import functools
  7. import inspect
  8. import os
  9. import platform
  10. import re
  11. import signal
  12. import socket
  13. import socketserver
  14. import subprocess
  15. import sys
  16. import textwrap
  17. import threading
  18. import time
  19. import types
  20. from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
  21. from http.server import SimpleHTTPRequestHandler
  22. from pathlib import Path
  23. from typing import TYPE_CHECKING, Any, TypeVar
  24. import psutil
  25. import uvicorn
  26. import reflex
  27. import reflex.environment
  28. import reflex.reflex
  29. import reflex.utils.build
  30. import reflex.utils.exec
  31. import reflex.utils.format
  32. import reflex.utils.prerequisites
  33. import reflex.utils.processes
  34. from reflex.components.component import CustomComponent
  35. from reflex.config import get_config
  36. from reflex.environment import environment
  37. from reflex.state import (
  38. BaseState,
  39. StateManager,
  40. StateManagerDisk,
  41. StateManagerMemory,
  42. StateManagerRedis,
  43. reload_state_module,
  44. )
  45. from reflex.utils import console
  46. from reflex.utils.export import export
  47. from reflex.utils.types import ASGIApp
  48. try:
  49. from selenium import webdriver
  50. from selenium.webdriver.remote.webdriver import WebDriver
  51. if TYPE_CHECKING:
  52. from selenium.webdriver.common.options import ArgOptions
  53. from selenium.webdriver.remote.webelement import WebElement
  54. has_selenium = True
  55. except ImportError:
  56. has_selenium = False
  57. # The timeout (minutes) to check for the port.
  58. DEFAULT_TIMEOUT = 15
  59. POLL_INTERVAL = 0.25
  60. FRONTEND_POPEN_ARGS = {}
  61. T = TypeVar("T")
  62. TimeoutType = int | float | None
  63. if platform.system() == "Windows":
  64. FRONTEND_POPEN_ARGS["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP # pyright: ignore [reportAttributeAccessIssue]
  65. FRONTEND_POPEN_ARGS["shell"] = True
  66. else:
  67. FRONTEND_POPEN_ARGS["start_new_session"] = True
  68. # borrowed from py3.11
  69. class chdir(contextlib.AbstractContextManager): # noqa: N801
  70. """Non thread-safe context manager to change the current working directory."""
  71. def __init__(self, path: str | Path):
  72. """Prepare contextmanager.
  73. Args:
  74. path: the path to change to
  75. """
  76. self.path = path
  77. self._old_cwd = []
  78. def __enter__(self):
  79. """Save current directory and perform chdir."""
  80. self._old_cwd.append(Path.cwd())
  81. os.chdir(self.path)
  82. def __exit__(self, *excinfo):
  83. """Change back to previous directory on stack.
  84. Args:
  85. excinfo: sys.exc_info captured in the context block
  86. """
  87. os.chdir(self._old_cwd.pop())
  88. @dataclasses.dataclass
  89. class AppHarness:
  90. """AppHarness executes a reflex app in-process for testing."""
  91. app_name: str
  92. app_source: (
  93. Callable[[], None] | types.ModuleType | str | functools.partial[Any] | None
  94. )
  95. app_path: Path
  96. app_module_path: Path
  97. app_module: types.ModuleType | None = None
  98. app_instance: reflex.App | None = None
  99. app_asgi: ASGIApp | None = None
  100. frontend_process: subprocess.Popen | None = None
  101. frontend_url: str | None = None
  102. frontend_output_thread: threading.Thread | None = None
  103. backend_thread: threading.Thread | None = None
  104. backend: uvicorn.Server | None = None
  105. state_manager: StateManager | None = None
  106. _frontends: list[WebDriver] = dataclasses.field(default_factory=list)
  107. @classmethod
  108. def create(
  109. cls,
  110. root: Path,
  111. app_source: (
  112. Callable[[], None] | types.ModuleType | str | functools.partial[Any] | None
  113. ) = None,
  114. app_name: str | None = None,
  115. ) -> AppHarness:
  116. """Create an AppHarness instance at root.
  117. Args:
  118. root: the directory that will contain the app under test.
  119. app_source: if specified, the source code from this function or module is used
  120. as the main module for the app. It may also be the raw source code text, as a str.
  121. If unspecified, then root must already contain a working reflex app and will be used directly.
  122. app_name: provide the name of the app, otherwise will be derived from app_source or root.
  123. Raises:
  124. ValueError: when app_source is a string and app_name is not provided.
  125. Returns:
  126. AppHarness instance
  127. """
  128. if app_name is None:
  129. if app_source is None:
  130. app_name = root.name
  131. elif isinstance(app_source, functools.partial):
  132. keywords = app_source.keywords
  133. slug_suffix = "_".join([str(v) for v in keywords.values()])
  134. func_name = app_source.func.__name__
  135. app_name = f"{func_name}_{slug_suffix}"
  136. app_name = re.sub(r"[^a-zA-Z0-9_]", "_", app_name)
  137. elif isinstance(app_source, str):
  138. msg = "app_name must be provided when app_source is a string."
  139. raise ValueError(msg)
  140. else:
  141. app_name = app_source.__name__
  142. app_name = app_name.lower()
  143. while "__" in app_name:
  144. app_name = app_name.replace("__", "_")
  145. return cls(
  146. app_name=app_name,
  147. app_source=app_source,
  148. app_path=root,
  149. app_module_path=root / app_name / f"{app_name}.py",
  150. )
  151. def get_state_name(self, state_cls_name: str) -> str:
  152. """Get the state name for the given state class name.
  153. Args:
  154. state_cls_name: The state class name
  155. Returns:
  156. The state name
  157. """
  158. return reflex.utils.format.to_snake_case(
  159. f"{self.app_name}___{self.app_name}___" + state_cls_name
  160. )
  161. def get_full_state_name(self, path: list[str]) -> str:
  162. """Get the full state name for the given state class name.
  163. Args:
  164. path: A list of state class names
  165. Returns:
  166. The full state name
  167. """
  168. # NOTE: using State.get_name() somehow causes trouble here
  169. # path = [State.get_name()] + [self.get_state_name(p) for p in path] # noqa: ERA001
  170. path = ["reflex___state____state"] + [self.get_state_name(p) for p in path]
  171. return ".".join(path)
  172. def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
  173. """Get the globals from a function or module object.
  174. Args:
  175. func: function or module object
  176. Returns:
  177. dict of globals
  178. """
  179. overrides = {}
  180. glbs = {}
  181. if not callable(func):
  182. return glbs
  183. if isinstance(func, functools.partial):
  184. overrides = func.keywords
  185. func = func.func
  186. for param in inspect.signature(func).parameters.values():
  187. if param.default is not inspect.Parameter.empty:
  188. glbs[param.name] = param.default
  189. glbs.update(overrides)
  190. return glbs
  191. def _get_source_from_app_source(self, app_source: Any) -> str:
  192. """Get the source from app_source.
  193. Args:
  194. app_source: function or module or str
  195. Returns:
  196. source code
  197. """
  198. if isinstance(app_source, str):
  199. return app_source
  200. source = inspect.getsource(app_source)
  201. source = re.sub(
  202. r"^\s*def\s+\w+\s*\(.*?\)(\s+->\s+\w+)?:", "", source, flags=re.DOTALL
  203. )
  204. return textwrap.dedent(source)
  205. def _initialize_app(self):
  206. # disable telemetry reporting for tests
  207. os.environ["TELEMETRY_ENABLED"] = "false"
  208. CustomComponent.create().get_component.cache_clear()
  209. self.app_path.mkdir(parents=True, exist_ok=True)
  210. if self.app_source is not None:
  211. app_globals = self._get_globals_from_signature(self.app_source)
  212. if isinstance(self.app_source, functools.partial):
  213. self.app_source = self.app_source.func
  214. # get the source from a function or module object
  215. source_code = "\n".join(
  216. [
  217. "\n".join(
  218. self.get_app_global_source(k, v) for k, v in app_globals.items()
  219. ),
  220. self._get_source_from_app_source(self.app_source),
  221. ]
  222. )
  223. get_config().loglevel = reflex.constants.LogLevel.INFO
  224. with chdir(self.app_path):
  225. reflex.reflex._init(
  226. name=self.app_name,
  227. template=reflex.constants.Templates.DEFAULT,
  228. )
  229. self.app_module_path.write_text(source_code)
  230. else:
  231. # Just initialize the web folder.
  232. with chdir(self.app_path):
  233. reflex.utils.prerequisites.initialize_frontend_dependencies()
  234. with chdir(self.app_path):
  235. # ensure config and app are reloaded when testing different app
  236. reflex.config.get_config(reload=True)
  237. # Ensure the AppHarness test does not skip State assignment due to running via pytest
  238. os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
  239. os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
  240. # Ensure we actually compile the app during first initialization.
  241. self.app_instance, self.app_module = (
  242. reflex.utils.prerequisites.get_and_validate_app(
  243. # Do not reload the module for pre-existing apps (only apps generated from source)
  244. reload=self.app_source is not None
  245. )
  246. )
  247. self.app_asgi = self.app_instance()
  248. if self.app_instance and isinstance(
  249. self.app_instance._state_manager, StateManagerRedis
  250. ):
  251. if self.app_instance._state is None:
  252. msg = "State is not set."
  253. raise RuntimeError(msg)
  254. # Create our own redis connection for testing.
  255. self.state_manager = StateManagerRedis.create(self.app_instance._state)
  256. else:
  257. self.state_manager = (
  258. self.app_instance._state_manager if self.app_instance else None
  259. )
  260. def _reload_state_module(self):
  261. """Reload the rx.State module to avoid conflict when reloading."""
  262. reload_state_module(module=f"{self.app_name}.{self.app_name}")
  263. def _get_backend_shutdown_handler(self):
  264. if self.backend is None:
  265. msg = "Backend was not initialized."
  266. raise RuntimeError(msg)
  267. original_shutdown = self.backend.shutdown
  268. async def _shutdown(*args, **kwargs) -> None:
  269. # ensure redis is closed before event loop
  270. if self.app_instance is not None and isinstance(
  271. self.app_instance._state_manager, StateManagerRedis
  272. ):
  273. with contextlib.suppress(ValueError):
  274. await self.app_instance._state_manager.close()
  275. # socketio shutdown handler
  276. if self.app_instance is not None and self.app_instance.sio is not None:
  277. with contextlib.suppress(TypeError):
  278. await self.app_instance.sio.shutdown()
  279. # sqlalchemy async engine shutdown handler
  280. try:
  281. async_engine = reflex.model.get_async_engine(None)
  282. except ValueError:
  283. pass
  284. else:
  285. await async_engine.dispose()
  286. await original_shutdown(*args, **kwargs)
  287. return _shutdown
  288. def _start_backend(self, port: int = 0):
  289. if self.app_asgi is None:
  290. msg = "App was not initialized."
  291. raise RuntimeError(msg)
  292. self.backend = uvicorn.Server(
  293. uvicorn.Config(
  294. app=self.app_asgi,
  295. host="127.0.0.1",
  296. port=port,
  297. )
  298. )
  299. self.backend.shutdown = self._get_backend_shutdown_handler()
  300. with chdir(self.app_path):
  301. self.backend_thread = threading.Thread(target=self.backend.run)
  302. self.backend_thread.start()
  303. async def _reset_backend_state_manager(self):
  304. """Reset the StateManagerRedis event loop affinity.
  305. This is necessary when the backend is restarted and the state manager is a
  306. StateManagerRedis instance.
  307. Raises:
  308. RuntimeError: when the state manager cannot be reset
  309. """
  310. if (
  311. self.app_instance is not None
  312. and isinstance(
  313. self.app_instance._state_manager,
  314. StateManagerRedis,
  315. )
  316. and self.app_instance._state is not None
  317. ):
  318. with contextlib.suppress(RuntimeError):
  319. await self.app_instance._state_manager.close()
  320. self.app_instance._state_manager = StateManagerRedis.create(
  321. state=self.app_instance._state,
  322. )
  323. if not isinstance(self.app_instance.state_manager, StateManagerRedis):
  324. msg = "Failed to reset state manager."
  325. raise RuntimeError(msg)
  326. def _start_frontend(self):
  327. # Set up the frontend.
  328. with chdir(self.app_path):
  329. config = reflex.config.get_config()
  330. config.api_url = "http://{}:{}".format(
  331. *self._poll_for_servers().getsockname(),
  332. )
  333. reflex.utils.build.setup_frontend(self.app_path)
  334. # Start the frontend.
  335. self.frontend_process = reflex.utils.processes.new_process(
  336. [
  337. *reflex.utils.prerequisites.get_js_package_executor(raise_on_none=True)[
  338. 0
  339. ],
  340. "run",
  341. "dev",
  342. ],
  343. cwd=self.app_path / reflex.utils.prerequisites.get_web_dir(),
  344. env={"PORT": "0"},
  345. **FRONTEND_POPEN_ARGS,
  346. )
  347. def _wait_frontend(self):
  348. while self.frontend_url is None:
  349. line = (
  350. self.frontend_process.stdout.readline() # pyright: ignore [reportOptionalMemberAccess]
  351. )
  352. if not line:
  353. break
  354. print(line) # for pytest diagnosis #noqa: T201
  355. m = re.search(reflex.constants.Next.FRONTEND_LISTENING_REGEX, line)
  356. if m is not None:
  357. self.frontend_url = m.group(1)
  358. config = reflex.config.get_config()
  359. config.deploy_url = self.frontend_url
  360. break
  361. if self.frontend_url is None:
  362. msg = "Frontend did not start"
  363. raise RuntimeError(msg)
  364. def consume_frontend_output():
  365. while True:
  366. try:
  367. line = (
  368. self.frontend_process.stdout.readline() # pyright: ignore [reportOptionalMemberAccess]
  369. )
  370. # catch I/O operation on closed file.
  371. except ValueError as e:
  372. console.error(str(e))
  373. break
  374. if not line:
  375. break
  376. self.frontend_output_thread = threading.Thread(target=consume_frontend_output)
  377. self.frontend_output_thread.start()
  378. def start(self) -> AppHarness:
  379. """Start the backend in a new thread and dev frontend as a separate process.
  380. Returns:
  381. self
  382. """
  383. self._initialize_app()
  384. self._start_backend()
  385. self._start_frontend()
  386. self._wait_frontend()
  387. return self
  388. @staticmethod
  389. def get_app_global_source(key: str, value: Any):
  390. """Get the source code of a global object.
  391. If value is a function or class we render the actual
  392. source of value otherwise we assign value to key.
  393. Args:
  394. key: variable name to assign value to.
  395. value: value of the global variable.
  396. Returns:
  397. The rendered app global code.
  398. """
  399. if not inspect.isclass(value) and not inspect.isfunction(value):
  400. return f"{key} = {value!r}"
  401. return inspect.getsource(value)
  402. def __enter__(self) -> AppHarness:
  403. """Contextmanager protocol for `start()`.
  404. Returns:
  405. Instance of AppHarness after calling start()
  406. """
  407. return self.start()
  408. def stop(self) -> None:
  409. """Stop the frontend and backend servers."""
  410. # Quit browsers first to avoid any lingering events being sent during shutdown.
  411. for driver in self._frontends:
  412. driver.quit()
  413. self._reload_state_module()
  414. if self.backend is not None:
  415. self.backend.should_exit = True
  416. if self.frontend_process is not None:
  417. # https://stackoverflow.com/a/70565806
  418. frontend_children = psutil.Process(self.frontend_process.pid).children(
  419. recursive=True,
  420. )
  421. if sys.platform == "win32":
  422. self.frontend_process.terminate()
  423. else:
  424. pgrp = os.getpgid(self.frontend_process.pid)
  425. os.killpg(pgrp, signal.SIGTERM)
  426. # kill any remaining child processes
  427. for child in frontend_children:
  428. # It's okay if the process is already gone.
  429. with contextlib.suppress(psutil.NoSuchProcess):
  430. child.terminate()
  431. _, still_alive = psutil.wait_procs(frontend_children, timeout=3)
  432. for child in still_alive:
  433. # It's okay if the process is already gone.
  434. with contextlib.suppress(psutil.NoSuchProcess):
  435. child.kill()
  436. # wait for main process to exit
  437. self.frontend_process.communicate()
  438. if self.backend_thread is not None:
  439. self.backend_thread.join()
  440. if self.frontend_output_thread is not None:
  441. self.frontend_output_thread.join()
  442. def __exit__(self, *excinfo) -> None:
  443. """Contextmanager protocol for `stop()`.
  444. Args:
  445. excinfo: sys.exc_info captured in the context block
  446. """
  447. self.stop()
  448. @staticmethod
  449. def _poll_for(
  450. target: Callable[[], T],
  451. timeout: TimeoutType = None,
  452. step: TimeoutType = None,
  453. ) -> T | bool:
  454. """Generic polling logic.
  455. Args:
  456. target: callable that returns truthy if polling condition is met.
  457. timeout: max polling time
  458. step: interval between checking target()
  459. Returns:
  460. return value of target() if truthy within timeout
  461. False if timeout elapses
  462. """
  463. if timeout is None:
  464. timeout = DEFAULT_TIMEOUT
  465. if step is None:
  466. step = POLL_INTERVAL
  467. deadline = time.time() + timeout
  468. while time.time() < deadline:
  469. success = target()
  470. if success:
  471. return success
  472. time.sleep(step)
  473. return False
  474. @staticmethod
  475. async def _poll_for_async(
  476. target: Callable[[], Coroutine[None, None, T]],
  477. timeout: TimeoutType = None,
  478. step: TimeoutType = None,
  479. ) -> T | bool:
  480. """Generic polling logic for async functions.
  481. Args:
  482. target: callable that returns truthy if polling condition is met.
  483. timeout: max polling time
  484. step: interval between checking target()
  485. Returns:
  486. return value of target() if truthy within timeout
  487. False if timeout elapses
  488. """
  489. if timeout is None:
  490. timeout = DEFAULT_TIMEOUT
  491. if step is None:
  492. step = POLL_INTERVAL
  493. deadline = time.time() + timeout
  494. while time.time() < deadline:
  495. success = await target()
  496. if success:
  497. return success
  498. await asyncio.sleep(step)
  499. return False
  500. def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
  501. """Poll backend server for listening sockets.
  502. Args:
  503. timeout: how long to wait for listening socket.
  504. Returns:
  505. first active listening socket on the backend
  506. Raises:
  507. RuntimeError: when the backend hasn't started running
  508. TimeoutError: when server or sockets are not ready
  509. """
  510. if self.backend is None:
  511. msg = "Backend is not running."
  512. raise RuntimeError(msg)
  513. backend = self.backend
  514. # check for servers to be initialized
  515. if not self._poll_for(
  516. target=lambda: getattr(backend, "servers", False),
  517. timeout=timeout,
  518. ):
  519. msg = "Backend servers are not initialized."
  520. raise TimeoutError(msg)
  521. # check for sockets to be listening
  522. if not self._poll_for(
  523. target=lambda: getattr(backend.servers[0], "sockets", False),
  524. timeout=timeout,
  525. ):
  526. msg = "Backend is not listening."
  527. raise TimeoutError(msg)
  528. return backend.servers[0].sockets[0]
  529. def frontend(
  530. self,
  531. driver_clz: type[WebDriver] | None = None,
  532. driver_kwargs: dict[str, Any] | None = None,
  533. driver_options: ArgOptions | None = None,
  534. driver_option_args: list[str] | None = None,
  535. driver_option_capabilities: dict[str, Any] | None = None,
  536. ) -> WebDriver:
  537. """Get a selenium webdriver instance pointed at the app.
  538. Args:
  539. driver_clz: webdriver.Chrome (default), webdriver.Firefox, webdriver.Safari,
  540. webdriver.Edge, etc
  541. driver_kwargs: additional keyword arguments to pass to the webdriver constructor
  542. driver_options: selenium ArgOptions instance to pass to the webdriver constructor
  543. driver_option_args: additional arguments for the webdriver options
  544. driver_option_capabilities: additional capabilities for the webdriver options
  545. Returns:
  546. Instance of the given webdriver navigated to the frontend url of the app.
  547. Raises:
  548. RuntimeError: when selenium is not importable or frontend is not running
  549. """
  550. if not has_selenium:
  551. msg = (
  552. "Frontend functionality requires `selenium` to be installed, "
  553. "and it could not be imported."
  554. )
  555. raise RuntimeError(msg)
  556. if self.frontend_url is None:
  557. msg = "Frontend is not running."
  558. raise RuntimeError(msg)
  559. want_headless = False
  560. if environment.APP_HARNESS_HEADLESS.get():
  561. want_headless = True
  562. if driver_clz is None:
  563. requested_driver = environment.APP_HARNESS_DRIVER.get()
  564. driver_clz = getattr(webdriver, requested_driver) # pyright: ignore [reportPossiblyUnboundVariable]
  565. if driver_options is None:
  566. driver_options = getattr(webdriver, f"{requested_driver}Options")() # pyright: ignore [reportPossiblyUnboundVariable]
  567. if driver_clz is webdriver.Chrome: # pyright: ignore [reportPossiblyUnboundVariable]
  568. if driver_options is None:
  569. driver_options = webdriver.ChromeOptions() # pyright: ignore [reportPossiblyUnboundVariable]
  570. driver_options.add_argument("--class=AppHarness")
  571. if want_headless:
  572. driver_options.add_argument("--headless=new")
  573. elif driver_clz is webdriver.Firefox: # pyright: ignore [reportPossiblyUnboundVariable]
  574. if driver_options is None:
  575. driver_options = webdriver.FirefoxOptions() # pyright: ignore [reportPossiblyUnboundVariable]
  576. if want_headless:
  577. driver_options.add_argument("-headless")
  578. elif driver_clz is webdriver.Edge: # pyright: ignore [reportPossiblyUnboundVariable]
  579. if driver_options is None:
  580. driver_options = webdriver.EdgeOptions() # pyright: ignore [reportPossiblyUnboundVariable]
  581. if want_headless:
  582. driver_options.add_argument("headless")
  583. if driver_options is None:
  584. msg = f"Could not determine options for {driver_clz}"
  585. raise RuntimeError(msg)
  586. if args := environment.APP_HARNESS_DRIVER_ARGS.get():
  587. for arg in args.split(","):
  588. driver_options.add_argument(arg)
  589. if driver_option_args is not None:
  590. for arg in driver_option_args:
  591. driver_options.add_argument(arg)
  592. if driver_option_capabilities is not None:
  593. for key, value in driver_option_capabilities.items():
  594. driver_options.set_capability(key, value)
  595. if driver_kwargs is None:
  596. driver_kwargs = {}
  597. driver = driver_clz(options=driver_options, **driver_kwargs) # pyright: ignore [reportOptionalCall, reportArgumentType]
  598. driver.get(self.frontend_url)
  599. self._frontends.append(driver)
  600. return driver
  601. async def get_state(self, token: str) -> BaseState:
  602. """Get the state associated with the given token.
  603. Args:
  604. token: The state token to look up.
  605. Returns:
  606. The state instance associated with the given token
  607. Raises:
  608. RuntimeError: when the app hasn't started running
  609. """
  610. if self.state_manager is None:
  611. msg = "state_manager is not set."
  612. raise RuntimeError(msg)
  613. try:
  614. return await self.state_manager.get_state(token)
  615. finally:
  616. if isinstance(self.state_manager, StateManagerRedis):
  617. await self.state_manager.close()
  618. async def set_state(self, token: str, **kwargs) -> None:
  619. """Set the state associated with the given token.
  620. Args:
  621. token: The state token to set.
  622. kwargs: Attributes to set on the state.
  623. Raises:
  624. RuntimeError: when the app hasn't started running
  625. """
  626. if self.state_manager is None:
  627. msg = "state_manager is not set."
  628. raise RuntimeError(msg)
  629. state = await self.get_state(token)
  630. for key, value in kwargs.items():
  631. setattr(state, key, value)
  632. try:
  633. await self.state_manager.set_state(token, state)
  634. finally:
  635. if isinstance(self.state_manager, StateManagerRedis):
  636. await self.state_manager.close()
  637. @contextlib.asynccontextmanager
  638. async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
  639. """Modify the state associated with the given token and send update to frontend.
  640. Args:
  641. token: The state token to modify
  642. Yields:
  643. The state instance associated with the given token
  644. Raises:
  645. RuntimeError: when the app hasn't started running
  646. """
  647. if self.state_manager is None:
  648. msg = "state_manager is not set."
  649. raise RuntimeError(msg)
  650. if self.app_instance is None:
  651. msg = "App is not running."
  652. raise RuntimeError(msg)
  653. app_state_manager = self.app_instance.state_manager
  654. if isinstance(self.state_manager, StateManagerRedis):
  655. # Temporarily replace the app's state manager with our own, since
  656. # the redis connection is on the backend_thread event loop
  657. self.app_instance._state_manager = self.state_manager
  658. try:
  659. async with self.app_instance.modify_state(token) as state:
  660. yield state
  661. finally:
  662. if isinstance(self.state_manager, StateManagerRedis):
  663. self.app_instance._state_manager = app_state_manager
  664. await self.state_manager.close()
  665. def poll_for_content(
  666. self,
  667. element: WebElement,
  668. timeout: TimeoutType = None,
  669. exp_not_equal: str = "",
  670. ) -> str:
  671. """Poll element.text for change.
  672. Args:
  673. element: selenium webdriver element to check
  674. timeout: how long to poll element.text
  675. exp_not_equal: exit the polling loop when the element text does not match
  676. Returns:
  677. The element text when the polling loop exited
  678. Raises:
  679. TimeoutError: when the timeout expires before text changes
  680. """
  681. if not self._poll_for(
  682. target=lambda: element.text != exp_not_equal,
  683. timeout=timeout,
  684. ):
  685. msg = f"{element} content remains {exp_not_equal!r} while polling."
  686. raise TimeoutError(msg)
  687. return element.text
  688. def poll_for_value(
  689. self,
  690. element: WebElement,
  691. timeout: TimeoutType = None,
  692. exp_not_equal: str | Sequence[str] = "",
  693. ) -> str | None:
  694. """Poll element.get_attribute("value") for change.
  695. Args:
  696. element: selenium webdriver element to check
  697. timeout: how long to poll element value attribute
  698. exp_not_equal: exit the polling loop when the value does not match
  699. Returns:
  700. The element value when the polling loop exited
  701. Raises:
  702. TimeoutError: when the timeout expires before value changes
  703. """
  704. exp_not_equal = (
  705. (exp_not_equal,) if isinstance(exp_not_equal, str) else exp_not_equal
  706. )
  707. if not self._poll_for(
  708. target=lambda: element.get_attribute("value") not in exp_not_equal,
  709. timeout=timeout,
  710. ):
  711. msg = f"{element} content remains {exp_not_equal!r} while polling."
  712. raise TimeoutError(msg)
  713. return element.get_attribute("value")
  714. def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
  715. """Poll app state_manager for any connected clients.
  716. Args:
  717. timeout: how long to wait for client states
  718. Returns:
  719. active state instances when the polling loop exited
  720. Raises:
  721. RuntimeError: when the app hasn't started running
  722. TimeoutError: when the timeout expires before any states are seen
  723. ValueError: when the state_manager is not a memory state manager
  724. """
  725. if self.app_instance is None:
  726. msg = "App is not running."
  727. raise RuntimeError(msg)
  728. state_manager = self.app_instance.state_manager
  729. if not isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
  730. msg = "Only works with memory or disk state manager"
  731. raise ValueError(msg)
  732. if not self._poll_for(
  733. target=lambda: state_manager.states,
  734. timeout=timeout,
  735. ):
  736. msg = "No states were observed while polling."
  737. raise TimeoutError(msg)
  738. return state_manager.states
  739. class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
  740. """SimpleHTTPRequestHandler with custom error page handling."""
  741. def __init__(self, *args, error_page_map: dict[int, Path], **kwargs):
  742. """Initialize the handler.
  743. Args:
  744. error_page_map: map of error code to error page path
  745. *args: passed through to superclass
  746. **kwargs: passed through to superclass
  747. """
  748. self.error_page_map = error_page_map
  749. super().__init__(*args, **kwargs)
  750. def send_error(
  751. self, code: int, message: str | None = None, explain: str | None = None
  752. ) -> None:
  753. """Send the error page for the given error code.
  754. If the code matches a custom error page, then message and explain are
  755. ignored.
  756. Args:
  757. code: the error code
  758. message: the error message
  759. explain: the error explanation
  760. """
  761. error_page = self.error_page_map.get(code)
  762. if error_page:
  763. self.send_response(code, message)
  764. self.send_header("Connection", "close")
  765. body = error_page.read_bytes()
  766. self.send_header("Content-Type", self.error_content_type)
  767. self.send_header("Content-Length", str(len(body)))
  768. self.end_headers()
  769. self.wfile.write(body)
  770. else:
  771. super().send_error(code, message, explain)
  772. class Subdir404TCPServer(socketserver.TCPServer):
  773. """TCPServer for SimpleHTTPRequestHandlerCustomErrors that serves from a subdir."""
  774. def __init__(
  775. self,
  776. *args,
  777. root: Path,
  778. error_page_map: dict[int, Path] | None,
  779. **kwargs,
  780. ):
  781. """Initialize the server.
  782. Args:
  783. root: the root directory to serve from
  784. error_page_map: map of error code to error page path
  785. *args: passed through to superclass
  786. **kwargs: passed through to superclass
  787. """
  788. self.root = root
  789. self.error_page_map = error_page_map or {}
  790. super().__init__(*args, **kwargs)
  791. def finish_request(self, request: socket.socket, client_address: tuple[str, int]):
  792. """Finish one request by instantiating RequestHandlerClass.
  793. Args:
  794. request: the requesting socket
  795. client_address: (host, port) referring to the client's address.
  796. """
  797. self.RequestHandlerClass(
  798. request,
  799. client_address,
  800. self,
  801. directory=str(self.root), # pyright: ignore [reportCallIssue]
  802. error_page_map=self.error_page_map, # pyright: ignore [reportCallIssue]
  803. )
  804. class AppHarnessProd(AppHarness):
  805. """AppHarnessProd executes a reflex app in-process for testing.
  806. In prod mode, instead of running `next dev` the app is exported as static
  807. files and served via the builtin python http.server with custom 404 redirect
  808. handling. Additionally, the backend runs in multi-worker mode.
  809. """
  810. frontend_thread: threading.Thread | None = None
  811. frontend_server: Subdir404TCPServer | None = None
  812. def _run_frontend(self):
  813. web_root = (
  814. self.app_path
  815. / reflex.utils.prerequisites.get_web_dir()
  816. / reflex.constants.Dirs.STATIC
  817. )
  818. error_page_map = {
  819. 404: web_root / "404.html",
  820. }
  821. with Subdir404TCPServer(
  822. ("", 0),
  823. SimpleHTTPRequestHandlerCustomErrors,
  824. root=web_root,
  825. error_page_map=error_page_map,
  826. ) as self.frontend_server:
  827. self.frontend_url = "http://localhost:{1}".format(
  828. *self.frontend_server.socket.getsockname()
  829. )
  830. self.frontend_server.serve_forever()
  831. def _start_frontend(self):
  832. # Set up the frontend.
  833. with chdir(self.app_path):
  834. config = reflex.config.get_config()
  835. config.api_url = "http://{}:{}".format(
  836. *self._poll_for_servers().getsockname(),
  837. )
  838. get_config().loglevel = reflex.constants.LogLevel.INFO
  839. reflex.utils.prerequisites.assert_in_reflex_dir()
  840. if reflex.utils.prerequisites.needs_reinit():
  841. reflex.reflex._init(name=get_config().app_name)
  842. export(
  843. zipping=False,
  844. frontend=True,
  845. backend=False,
  846. loglevel=reflex.constants.LogLevel.INFO,
  847. env=reflex.constants.Env.PROD,
  848. )
  849. self.frontend_thread = threading.Thread(target=self._run_frontend)
  850. self.frontend_thread.start()
  851. def _wait_frontend(self):
  852. self._poll_for(lambda: self.frontend_server is not None)
  853. if self.frontend_server is None or not self.frontend_server.socket.fileno():
  854. msg = "Frontend did not start"
  855. raise RuntimeError(msg)
  856. def _start_backend(self):
  857. if self.app_asgi is None:
  858. msg = "App was not initialized."
  859. raise RuntimeError(msg)
  860. environment.REFLEX_SKIP_COMPILE.set(True)
  861. self.backend = uvicorn.Server(
  862. uvicorn.Config(
  863. app=self.app_asgi,
  864. host="127.0.0.1",
  865. port=0,
  866. workers=reflex.utils.processes.get_num_workers(),
  867. ),
  868. )
  869. self.backend.shutdown = self._get_backend_shutdown_handler()
  870. self.backend_thread = threading.Thread(target=self.backend.run)
  871. self.backend_thread.start()
  872. def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
  873. try:
  874. return super()._poll_for_servers(timeout)
  875. finally:
  876. environment.REFLEX_SKIP_COMPILE.set(None)
  877. def stop(self):
  878. """Stop the frontend python webserver."""
  879. super().stop()
  880. if self.frontend_server is not None:
  881. self.frontend_server.shutdown()
  882. if self.frontend_thread is not None:
  883. self.frontend_thread.join()