testing.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959
  1. """reflex.testing - tools for testing reflex apps."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import dataclasses
  6. import functools
  7. import inspect
  8. import os
  9. import pathlib
  10. import platform
  11. import re
  12. import signal
  13. import socket
  14. import socketserver
  15. import subprocess
  16. import textwrap
  17. import threading
  18. import time
  19. import types
  20. from http.server import SimpleHTTPRequestHandler
  21. from typing import (
  22. TYPE_CHECKING,
  23. Any,
  24. AsyncIterator,
  25. Callable,
  26. Coroutine,
  27. List,
  28. Optional,
  29. Type,
  30. TypeVar,
  31. Union,
  32. )
  33. import psutil
  34. import uvicorn
  35. import reflex
  36. import reflex.reflex
  37. import reflex.utils.build
  38. import reflex.utils.exec
  39. import reflex.utils.format
  40. import reflex.utils.prerequisites
  41. import reflex.utils.processes
  42. from reflex.state import (
  43. BaseState,
  44. StateManagerMemory,
  45. StateManagerRedis,
  46. reload_state_module,
  47. )
  48. try:
  49. from selenium import webdriver # pyright: ignore [reportMissingImports]
  50. from selenium.webdriver.remote.webdriver import ( # pyright: ignore [reportMissingImports]
  51. WebDriver,
  52. )
  53. if TYPE_CHECKING:
  54. from selenium.webdriver.common.options import (
  55. ArgOptions, # pyright: ignore [reportMissingImports]
  56. )
  57. from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports]
  58. WebElement,
  59. )
  60. has_selenium = True
  61. except ImportError:
  62. has_selenium = False
# Default polling timeout, in seconds (it is added to a clock reading in `_poll_for`).
DEFAULT_TIMEOUT = 15
# Default pause between polling attempts, in seconds.
POLL_INTERVAL = 0.25
# Extra keyword arguments passed when the frontend process is spawned.
FRONTEND_POPEN_ARGS = {}
T = TypeVar("T")
TimeoutType = Optional[Union[int, float]]

# Start the frontend in its own process group / session on every platform.
if platform.system() == "Windows":
    FRONTEND_POPEN_ARGS["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP  # type: ignore
    FRONTEND_POPEN_ARGS["shell"] = True
else:
    FRONTEND_POPEN_ARGS["start_new_session"] = True
  74. # borrowed from py3.11
  75. class chdir(contextlib.AbstractContextManager):
  76. """Non thread-safe context manager to change the current working directory."""
  77. def __init__(self, path):
  78. """Prepare contextmanager.
  79. Args:
  80. path: the path to change to
  81. """
  82. self.path = path
  83. self._old_cwd = []
  84. def __enter__(self):
  85. """Save current directory and perform chdir."""
  86. self._old_cwd.append(os.getcwd())
  87. os.chdir(self.path)
  88. def __exit__(self, *excinfo):
  89. """Change back to previous directory on stack.
  90. Args:
  91. excinfo: sys.exc_info captured in the context block
  92. """
  93. os.chdir(self._old_cwd.pop())
@dataclasses.dataclass
class AppHarness:
    """AppHarness executes a reflex app in-process for testing."""

    # Name of the app; `create` also uses it for the package and module file name.
    app_name: str
    # Function, module, partial or raw source text used to generate the app's
    # main module; None when an existing app under app_path is used directly.
    app_source: Optional[
        types.FunctionType | types.ModuleType | str | functools.partial[Any]
    ]
    # Directory that contains the app under test.
    app_path: pathlib.Path
    # Path of the app's main module file.
    app_module_path: pathlib.Path
    # Compiled app module and its `app` attribute; both set by `_initialize_app`.
    app_module: Optional[types.ModuleType] = None
    app_instance: Optional[reflex.App] = None
    # Frontend dev-server process (set by `_start_frontend`) and the URL parsed
    # from its output (set by `_wait_frontend`).
    frontend_process: Optional[subprocess.Popen] = None
    frontend_url: Optional[str] = None
    # Thread that keeps draining the frontend process output.
    frontend_output_thread: Optional[threading.Thread] = None
    # Thread running the uvicorn server, and the server itself (set by `_start_backend`).
    backend_thread: Optional[threading.Thread] = None
    backend: Optional[uvicorn.Server] = None
    # State manager used by the test helpers (set by `_initialize_app`).
    state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
    # Webdrivers handed out by `frontend`; they are quit in `stop`.
    _frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
    # Decorated pages registered while loading the app; removed again in `stop`.
    _decorated_pages: list = dataclasses.field(default_factory=list)
  113. @classmethod
  114. def create(
  115. cls,
  116. root: pathlib.Path,
  117. app_source: Optional[
  118. types.FunctionType | types.ModuleType | str | functools.partial[Any]
  119. ] = None,
  120. app_name: Optional[str] = None,
  121. ) -> "AppHarness":
  122. """Create an AppHarness instance at root.
  123. Args:
  124. root: the directory that will contain the app under test.
  125. app_source: if specified, the source code from this function or module is used
  126. as the main module for the app. It may also be the raw source code text, as a str.
  127. If unspecified, then root must already contain a working reflex app and will be used directly.
  128. app_name: provide the name of the app, otherwise will be derived from app_source or root.
  129. Raises:
  130. ValueError: when app_source is a string and app_name is not provided.
  131. Returns:
  132. AppHarness instance
  133. """
  134. if app_name is None:
  135. if app_source is None:
  136. app_name = root.name
  137. elif isinstance(app_source, functools.partial):
  138. keywords = app_source.keywords
  139. slug_suffix = "_".join([str(v) for v in keywords.values()])
  140. func_name = app_source.func.__name__
  141. app_name = f"{func_name}_{slug_suffix}"
  142. app_name = re.sub(r"[^a-zA-Z0-9_]", "_", app_name)
  143. elif isinstance(app_source, str):
  144. raise ValueError(
  145. "app_name must be provided when app_source is a string."
  146. )
  147. else:
  148. app_name = app_source.__name__
  149. app_name = app_name.lower()
  150. while "__" in app_name:
  151. app_name = app_name.replace("__", "_")
  152. return cls(
  153. app_name=app_name,
  154. app_source=app_source,
  155. app_path=root,
  156. app_module_path=root / app_name / f"{app_name}.py",
  157. )
  158. def get_state_name(self, state_cls_name: str) -> str:
  159. """Get the state name for the given state class name.
  160. Args:
  161. state_cls_name: The state class name
  162. Returns:
  163. The state name
  164. """
  165. return reflex.utils.format.to_snake_case(
  166. f"{self.app_name}___{self.app_name}___" + state_cls_name
  167. )
  168. def get_full_state_name(self, path: List[str]) -> str:
  169. """Get the full state name for the given state class name.
  170. Args:
  171. path: A list of state class names
  172. Returns:
  173. The full state name
  174. """
  175. # NOTE: using State.get_name() somehow causes trouble here
  176. # path = [State.get_name()] + [self.get_state_name(p) for p in path]
  177. path = ["reflex___state____state"] + [self.get_state_name(p) for p in path]
  178. return ".".join(path)
  179. def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
  180. """Get the globals from a function or module object.
  181. Args:
  182. func: function or module object
  183. Returns:
  184. dict of globals
  185. """
  186. overrides = {}
  187. glbs = {}
  188. if not callable(func):
  189. return glbs
  190. if isinstance(func, functools.partial):
  191. overrides = func.keywords
  192. func = func.func
  193. for param in inspect.signature(func).parameters.values():
  194. if param.default is not inspect.Parameter.empty:
  195. glbs[param.name] = param.default
  196. glbs.update(overrides)
  197. return glbs
  198. def _get_source_from_app_source(self, app_source: Any) -> str:
  199. """Get the source from app_source.
  200. Args:
  201. app_source: function or module or str
  202. Returns:
  203. source code
  204. """
  205. if isinstance(app_source, str):
  206. return app_source
  207. source = inspect.getsource(app_source)
  208. source = re.sub(
  209. r"^\s*def\s+\w+\s*\(.*?\)(\s+->\s+\w+)?:", "", source, flags=re.DOTALL
  210. )
  211. return textwrap.dedent(source)
    def _initialize_app(self):
        """Generate (if needed), compile and load the app under test.

        When app_source is set, a fresh app is initialized under app_path and its
        main module is overwritten with the rendered source. The app is then
        compiled from within app_path, and app_module, app_instance,
        state_manager and _decorated_pages are populated.
        """
        os.environ["TELEMETRY_ENABLED"] = ""  # disable telemetry reporting for tests
        self.app_path.mkdir(parents=True, exist_ok=True)
        if self.app_source is not None:
            # Defaults / partial keywords of the source function become module globals.
            app_globals = self._get_globals_from_signature(self.app_source)
            if isinstance(self.app_source, functools.partial):
                self.app_source = self.app_source.func  # type: ignore
            # get the source from a function or module object
            source_code = "\n".join(
                [
                    "\n".join(
                        self.get_app_global_source(k, v) for k, v in app_globals.items()
                    ),
                    self._get_source_from_app_source(self.app_source),
                ]
            )
            with chdir(self.app_path):
                reflex.reflex._init(
                    name=self.app_name,
                    template=reflex.constants.Templates.DEFAULT,
                    loglevel=reflex.constants.LogLevel.INFO,
                )
                # Replace the generated main module with the source under test.
                self.app_module_path.write_text(source_code)
        with chdir(self.app_path):
            # ensure config and app are reloaded when testing different app
            reflex.config.get_config(reload=True)
            # Save decorated pages before importing the test app module
            before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy()
            # Ensure the AppHarness test does not skip State assignment due to running via pytest
            os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
            self.app_module = reflex.utils.prerequisites.get_compiled_app(
                # Do not reload the module for pre-existing apps (only apps generated from source)
                reload=self.app_source is not None
            )
            # Save the pages that were added during testing
            self._decorated_pages = [
                p
                for p in reflex.app.DECORATED_PAGES[self.app_name]
                if p not in before_decorated_pages
            ]
        self.app_instance = self.app_module.app
        if isinstance(self.app_instance._state_manager, StateManagerRedis):
            # Create our own redis connection for testing.
            self.state_manager = StateManagerRedis.create(self.app_instance.state)
        else:
            # Otherwise share the app's own state manager.
            self.state_manager = self.app_instance._state_manager
  258. def _reload_state_module(self):
  259. """Reload the rx.State module to avoid conflict when reloading."""
  260. reload_state_module(module=f"{self.app_name}.{self.app_name}")
  261. def _get_backend_shutdown_handler(self):
  262. if self.backend is None:
  263. raise RuntimeError("Backend was not initialized.")
  264. original_shutdown = self.backend.shutdown
  265. async def _shutdown_redis(*args, **kwargs) -> None:
  266. # ensure redis is closed before event loop
  267. try:
  268. if self.app_instance is not None and isinstance(
  269. self.app_instance.state_manager, StateManagerRedis
  270. ):
  271. await self.app_instance.state_manager.close()
  272. except ValueError:
  273. pass
  274. await original_shutdown(*args, **kwargs)
  275. return _shutdown_redis
  276. def _start_backend(self, port=0):
  277. if self.app_instance is None:
  278. raise RuntimeError("App was not initialized.")
  279. self.backend = uvicorn.Server(
  280. uvicorn.Config(
  281. app=self.app_instance.api,
  282. host="127.0.0.1",
  283. port=port,
  284. )
  285. )
  286. self.backend.shutdown = self._get_backend_shutdown_handler()
  287. self.backend_thread = threading.Thread(target=self.backend.run)
  288. self.backend_thread.start()
  289. async def _reset_backend_state_manager(self):
  290. """Reset the StateManagerRedis event loop affinity.
  291. This is necessary when the backend is restarted and the state manager is a
  292. StateManagerRedis instance.
  293. """
  294. if (
  295. self.app_instance is not None
  296. and isinstance(
  297. self.app_instance.state_manager,
  298. StateManagerRedis,
  299. )
  300. and self.app_instance.state is not None
  301. ):
  302. with contextlib.suppress(RuntimeError):
  303. await self.app_instance.state_manager.close()
  304. self.app_instance._state_manager = StateManagerRedis.create(
  305. state=self.app_instance.state,
  306. )
  307. assert isinstance(self.app_instance.state_manager, StateManagerRedis)
  308. def _start_frontend(self):
  309. # Set up the frontend.
  310. with chdir(self.app_path):
  311. config = reflex.config.get_config()
  312. config.api_url = "http://{0}:{1}".format(
  313. *self._poll_for_servers().getsockname(),
  314. )
  315. reflex.utils.build.setup_frontend(self.app_path)
  316. # Start the frontend.
  317. self.frontend_process = reflex.utils.processes.new_process(
  318. [reflex.utils.prerequisites.get_package_manager(), "run", "dev"],
  319. cwd=self.app_path / reflex.utils.prerequisites.get_web_dir(),
  320. env={"PORT": "0"},
  321. **FRONTEND_POPEN_ARGS,
  322. )
  323. def _wait_frontend(self):
  324. while self.frontend_url is None:
  325. line = (
  326. self.frontend_process.stdout.readline() # pyright: ignore [reportOptionalMemberAccess]
  327. )
  328. if not line:
  329. break
  330. print(line) # for pytest diagnosis
  331. m = re.search(reflex.constants.Next.FRONTEND_LISTENING_REGEX, line)
  332. if m is not None:
  333. self.frontend_url = m.group(1)
  334. config = reflex.config.get_config()
  335. config.deploy_url = self.frontend_url
  336. break
  337. if self.frontend_url is None:
  338. raise RuntimeError("Frontend did not start")
  339. def consume_frontend_output():
  340. while True:
  341. line = (
  342. self.frontend_process.stdout.readline() # pyright: ignore [reportOptionalMemberAccess]
  343. )
  344. if not line:
  345. break
  346. print(line)
  347. self.frontend_output_thread = threading.Thread(target=consume_frontend_output)
  348. self.frontend_output_thread.start()
  349. def start(self) -> "AppHarness":
  350. """Start the backend in a new thread and dev frontend as a separate process.
  351. Returns:
  352. self
  353. """
  354. self._initialize_app()
  355. self._start_backend()
  356. self._start_frontend()
  357. self._wait_frontend()
  358. return self
  359. @staticmethod
  360. def get_app_global_source(key, value):
  361. """Get the source code of a global object.
  362. If value is a function or class we render the actual
  363. source of value otherwise we assign value to key.
  364. Args:
  365. key: variable name to assign value to.
  366. value: value of the global variable.
  367. Returns:
  368. The rendered app global code.
  369. """
  370. if not inspect.isclass(value) and not inspect.isfunction(value):
  371. return f"{key} = {value!r}"
  372. return inspect.getsource(value)
  373. def __enter__(self) -> "AppHarness":
  374. """Contextmanager protocol for `start()`.
  375. Returns:
  376. Instance of AppHarness after calling start()
  377. """
  378. return self.start()
    def stop(self) -> None:
        """Stop the frontend and backend servers.

        Also quits every webdriver handed out by `frontend` and removes the
        decorated pages that were registered while loading the app.
        """
        self._reload_state_module()
        if self.backend is not None:
            # Ask uvicorn to exit; the thread is joined further below.
            self.backend.should_exit = True
        if self.frontend_process is not None:
            # https://stackoverflow.com/a/70565806
            # Snapshot the child processes before signalling the parent.
            frontend_children = psutil.Process(self.frontend_process.pid).children(
                recursive=True,
            )
            if platform.system() == "Windows":
                self.frontend_process.terminate()
            else:
                # Signal the whole process group of the frontend process.
                pgrp = os.getpgid(self.frontend_process.pid)
                os.killpg(pgrp, signal.SIGTERM)
            # kill any remaining child processes
            for child in frontend_children:
                # It's okay if the process is already gone.
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.terminate()
            # Give the children up to 3 seconds to exit, then force-kill the rest.
            _, still_alive = psutil.wait_procs(frontend_children, timeout=3)
            for child in still_alive:
                # It's okay if the process is already gone.
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.kill()
            # wait for main process to exit
            self.frontend_process.communicate()
        if self.backend_thread is not None:
            self.backend_thread.join()
        if self.frontend_output_thread is not None:
            self.frontend_output_thread.join()
        for driver in self._frontends:
            driver.quit()
        # Cleanup decorated pages added during testing
        for page in self._decorated_pages:
            reflex.app.DECORATED_PAGES[self.app_name].remove(page)
  415. def __exit__(self, *excinfo) -> None:
  416. """Contextmanager protocol for `stop()`.
  417. Args:
  418. excinfo: sys.exc_info captured in the context block
  419. """
  420. self.stop()
  421. @staticmethod
  422. def _poll_for(
  423. target: Callable[[], T],
  424. timeout: TimeoutType = None,
  425. step: TimeoutType = None,
  426. ) -> T | bool:
  427. """Generic polling logic.
  428. Args:
  429. target: callable that returns truthy if polling condition is met.
  430. timeout: max polling time
  431. step: interval between checking target()
  432. Returns:
  433. return value of target() if truthy within timeout
  434. False if timeout elapses
  435. """
  436. if timeout is None:
  437. timeout = DEFAULT_TIMEOUT
  438. if step is None:
  439. step = POLL_INTERVAL
  440. deadline = time.time() + timeout
  441. while time.time() < deadline:
  442. success = target()
  443. if success:
  444. return success
  445. time.sleep(step)
  446. return False
  447. @staticmethod
  448. async def _poll_for_async(
  449. target: Callable[[], Coroutine[None, None, T]],
  450. timeout: TimeoutType = None,
  451. step: TimeoutType = None,
  452. ) -> T | bool:
  453. """Generic polling logic for async functions.
  454. Args:
  455. target: callable that returns truthy if polling condition is met.
  456. timeout: max polling time
  457. step: interval between checking target()
  458. Returns:
  459. return value of target() if truthy within timeout
  460. False if timeout elapses
  461. """
  462. if timeout is None:
  463. timeout = DEFAULT_TIMEOUT
  464. if step is None:
  465. step = POLL_INTERVAL
  466. deadline = time.time() + timeout
  467. while time.time() < deadline:
  468. success = await target()
  469. if success:
  470. return success
  471. await asyncio.sleep(step)
  472. return False
  473. def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
  474. """Poll backend server for listening sockets.
  475. Args:
  476. timeout: how long to wait for listening socket.
  477. Returns:
  478. first active listening socket on the backend
  479. Raises:
  480. RuntimeError: when the backend hasn't started running
  481. TimeoutError: when server or sockets are not ready
  482. """
  483. if self.backend is None:
  484. raise RuntimeError("Backend is not running.")
  485. backend = self.backend
  486. # check for servers to be initialized
  487. if not self._poll_for(
  488. target=lambda: getattr(backend, "servers", False),
  489. timeout=timeout,
  490. ):
  491. raise TimeoutError("Backend servers are not initialized.")
  492. # check for sockets to be listening
  493. if not self._poll_for(
  494. target=lambda: getattr(backend.servers[0], "sockets", False),
  495. timeout=timeout,
  496. ):
  497. raise TimeoutError("Backend is not listening.")
  498. return backend.servers[0].sockets[0]
  499. def frontend(
  500. self,
  501. driver_clz: Optional[Type["WebDriver"]] = None,
  502. driver_kwargs: dict[str, Any] | None = None,
  503. driver_options: ArgOptions | None = None,
  504. driver_option_args: List[str] | None = None,
  505. driver_option_capabilities: dict[str, Any] | None = None,
  506. ) -> "WebDriver":
  507. """Get a selenium webdriver instance pointed at the app.
  508. Args:
  509. driver_clz: webdriver.Chrome (default), webdriver.Firefox, webdriver.Safari,
  510. webdriver.Edge, etc
  511. driver_kwargs: additional keyword arguments to pass to the webdriver constructor
  512. driver_options: selenium ArgOptions instance to pass to the webdriver constructor
  513. driver_option_args: additional arguments for the webdriver options
  514. driver_option_capabilities: additional capabilities for the webdriver options
  515. Returns:
  516. Instance of the given webdriver navigated to the frontend url of the app.
  517. Raises:
  518. RuntimeError: when selenium is not importable or frontend is not running
  519. """
  520. if not has_selenium:
  521. raise RuntimeError(
  522. "Frontend functionality requires `selenium` to be installed, "
  523. "and it could not be imported."
  524. )
  525. if self.frontend_url is None:
  526. raise RuntimeError("Frontend is not running.")
  527. want_headless = False
  528. if os.environ.get("APP_HARNESS_HEADLESS"):
  529. want_headless = True
  530. if driver_clz is None:
  531. requested_driver = os.environ.get("APP_HARNESS_DRIVER", "Chrome")
  532. driver_clz = getattr(webdriver, requested_driver)
  533. if driver_options is None:
  534. driver_options = getattr(webdriver, f"{requested_driver}Options")()
  535. if driver_clz is webdriver.Chrome:
  536. if driver_options is None:
  537. driver_options = webdriver.ChromeOptions()
  538. driver_options.add_argument("--class=AppHarness")
  539. if want_headless:
  540. driver_options.add_argument("--headless=new")
  541. elif driver_clz is webdriver.Firefox:
  542. if driver_options is None:
  543. driver_options = webdriver.FirefoxOptions()
  544. if want_headless:
  545. driver_options.add_argument("-headless")
  546. elif driver_clz is webdriver.Edge:
  547. if driver_options is None:
  548. driver_options = webdriver.EdgeOptions()
  549. if want_headless:
  550. driver_options.add_argument("headless")
  551. if driver_options is None:
  552. raise RuntimeError(f"Could not determine options for {driver_clz}")
  553. if args := os.environ.get("APP_HARNESS_DRIVER_ARGS"):
  554. for arg in args.split(","):
  555. driver_options.add_argument(arg)
  556. if driver_option_args is not None:
  557. for arg in driver_option_args:
  558. driver_options.add_argument(arg)
  559. if driver_option_capabilities is not None:
  560. for key, value in driver_option_capabilities.items():
  561. driver_options.set_capability(key, value)
  562. if driver_kwargs is None:
  563. driver_kwargs = {}
  564. driver = driver_clz(options=driver_options, **driver_kwargs) # type: ignore
  565. driver.get(self.frontend_url)
  566. self._frontends.append(driver)
  567. return driver
  568. async def get_state(self, token: str) -> BaseState:
  569. """Get the state associated with the given token.
  570. Args:
  571. token: The state token to look up.
  572. Returns:
  573. The state instance associated with the given token
  574. Raises:
  575. RuntimeError: when the app hasn't started running
  576. """
  577. if self.state_manager is None:
  578. raise RuntimeError("state_manager is not set.")
  579. try:
  580. return await self.state_manager.get_state(token)
  581. finally:
  582. if isinstance(self.state_manager, StateManagerRedis):
  583. await self.state_manager.close()
  584. async def set_state(self, token: str, **kwargs) -> None:
  585. """Set the state associated with the given token.
  586. Args:
  587. token: The state token to set.
  588. kwargs: Attributes to set on the state.
  589. Raises:
  590. RuntimeError: when the app hasn't started running
  591. """
  592. if self.state_manager is None:
  593. raise RuntimeError("state_manager is not set.")
  594. state = await self.get_state(token)
  595. for key, value in kwargs.items():
  596. setattr(state, key, value)
  597. try:
  598. await self.state_manager.set_state(token, state)
  599. finally:
  600. if isinstance(self.state_manager, StateManagerRedis):
  601. await self.state_manager.close()
  602. @contextlib.asynccontextmanager
  603. async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
  604. """Modify the state associated with the given token and send update to frontend.
  605. Args:
  606. token: The state token to modify
  607. Yields:
  608. The state instance associated with the given token
  609. Raises:
  610. RuntimeError: when the app hasn't started running
  611. """
  612. if self.state_manager is None:
  613. raise RuntimeError("state_manager is not set.")
  614. if self.app_instance is None:
  615. raise RuntimeError("App is not running.")
  616. app_state_manager = self.app_instance.state_manager
  617. if isinstance(self.state_manager, StateManagerRedis):
  618. # Temporarily replace the app's state manager with our own, since
  619. # the redis connection is on the backend_thread event loop
  620. self.app_instance._state_manager = self.state_manager
  621. try:
  622. async with self.app_instance.modify_state(token) as state:
  623. yield state
  624. finally:
  625. if isinstance(self.state_manager, StateManagerRedis):
  626. self.app_instance._state_manager = app_state_manager
  627. await self.state_manager.close()
  628. def poll_for_content(
  629. self,
  630. element: "WebElement",
  631. timeout: TimeoutType = None,
  632. exp_not_equal: str = "",
  633. ) -> str:
  634. """Poll element.text for change.
  635. Args:
  636. element: selenium webdriver element to check
  637. timeout: how long to poll element.text
  638. exp_not_equal: exit the polling loop when the element text does not match
  639. Returns:
  640. The element text when the polling loop exited
  641. Raises:
  642. TimeoutError: when the timeout expires before text changes
  643. """
  644. if not self._poll_for(
  645. target=lambda: element.text != exp_not_equal,
  646. timeout=timeout,
  647. ):
  648. raise TimeoutError(
  649. f"{element} content remains {exp_not_equal!r} while polling.",
  650. )
  651. return element.text
  652. def poll_for_value(
  653. self,
  654. element: "WebElement",
  655. timeout: TimeoutType = None,
  656. exp_not_equal: str = "",
  657. ) -> Optional[str]:
  658. """Poll element.get_attribute("value") for change.
  659. Args:
  660. element: selenium webdriver element to check
  661. timeout: how long to poll element value attribute
  662. exp_not_equal: exit the polling loop when the value does not match
  663. Returns:
  664. The element value when the polling loop exited
  665. Raises:
  666. TimeoutError: when the timeout expires before value changes
  667. """
  668. if not self._poll_for(
  669. target=lambda: element.get_attribute("value") != exp_not_equal,
  670. timeout=timeout,
  671. ):
  672. raise TimeoutError(
  673. f"{element} content remains {exp_not_equal!r} while polling.",
  674. )
  675. return element.get_attribute("value")
  676. def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
  677. """Poll app state_manager for any connected clients.
  678. Args:
  679. timeout: how long to wait for client states
  680. Returns:
  681. active state instances when the polling loop exited
  682. Raises:
  683. RuntimeError: when the app hasn't started running
  684. TimeoutError: when the timeout expires before any states are seen
  685. """
  686. if self.app_instance is None:
  687. raise RuntimeError("App is not running.")
  688. state_manager = self.app_instance.state_manager
  689. assert isinstance(
  690. state_manager, StateManagerMemory
  691. ), "Only works with memory state manager"
  692. if not self._poll_for(
  693. target=lambda: state_manager.states,
  694. timeout=timeout,
  695. ):
  696. raise TimeoutError("No states were observed while polling.")
  697. return state_manager.states
  698. class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
  699. """SimpleHTTPRequestHandler with custom error page handling."""
  700. def __init__(self, *args, error_page_map: dict[int, pathlib.Path], **kwargs):
  701. """Initialize the handler.
  702. Args:
  703. error_page_map: map of error code to error page path
  704. *args: passed through to superclass
  705. **kwargs: passed through to superclass
  706. """
  707. self.error_page_map = error_page_map
  708. super().__init__(*args, **kwargs)
  709. def send_error(
  710. self, code: int, message: str | None = None, explain: str | None = None
  711. ) -> None:
  712. """Send the error page for the given error code.
  713. If the code matches a custom error page, then message and explain are
  714. ignored.
  715. Args:
  716. code: the error code
  717. message: the error message
  718. explain: the error explanation
  719. """
  720. error_page = self.error_page_map.get(code)
  721. if error_page:
  722. self.send_response(code, message)
  723. self.send_header("Connection", "close")
  724. body = error_page.read_bytes()
  725. self.send_header("Content-Type", self.error_content_type)
  726. self.send_header("Content-Length", str(len(body)))
  727. self.end_headers()
  728. self.wfile.write(body)
  729. else:
  730. super().send_error(code, message, explain)
  731. class Subdir404TCPServer(socketserver.TCPServer):
  732. """TCPServer for SimpleHTTPRequestHandlerCustomErrors that serves from a subdir."""
  733. def __init__(
  734. self,
  735. *args,
  736. root: pathlib.Path,
  737. error_page_map: dict[int, pathlib.Path] | None,
  738. **kwargs,
  739. ):
  740. """Initialize the server.
  741. Args:
  742. root: the root directory to serve from
  743. error_page_map: map of error code to error page path
  744. *args: passed through to superclass
  745. **kwargs: passed through to superclass
  746. """
  747. self.root = root
  748. self.error_page_map = error_page_map or {}
  749. super().__init__(*args, **kwargs)
  750. def finish_request(self, request: socket.socket, client_address: tuple[str, int]):
  751. """Finish one request by instantiating RequestHandlerClass.
  752. Args:
  753. request: the requesting socket
  754. client_address: (host, port) referring to the client’s address.
  755. """
  756. self.RequestHandlerClass(
  757. request,
  758. client_address,
  759. self,
  760. directory=str(self.root), # type: ignore
  761. error_page_map=self.error_page_map, # type: ignore
  762. )
  763. class AppHarnessProd(AppHarness):
  764. """AppHarnessProd executes a reflex app in-process for testing.
  765. In prod mode, instead of running `next dev` the app is exported as static
  766. files and served via the builtin python http.server with custom 404 redirect
  767. handling. Additionally, the backend runs in multi-worker mode.
  768. """
  769. frontend_thread: Optional[threading.Thread] = None
  770. frontend_server: Optional[Subdir404TCPServer] = None
  771. def _run_frontend(self):
  772. web_root = (
  773. self.app_path
  774. / reflex.utils.prerequisites.get_web_dir()
  775. / reflex.constants.Dirs.STATIC
  776. )
  777. error_page_map = {
  778. 404: web_root / "404.html",
  779. }
  780. with Subdir404TCPServer(
  781. ("", 0),
  782. SimpleHTTPRequestHandlerCustomErrors,
  783. root=web_root,
  784. error_page_map=error_page_map,
  785. ) as self.frontend_server:
  786. self.frontend_url = "http://localhost:{1}".format(
  787. *self.frontend_server.socket.getsockname()
  788. )
  789. self.frontend_server.serve_forever()
  790. def _start_frontend(self):
  791. # Set up the frontend.
  792. with chdir(self.app_path):
  793. config = reflex.config.get_config()
  794. config.api_url = "http://{0}:{1}".format(
  795. *self._poll_for_servers().getsockname(),
  796. )
  797. reflex.reflex.export(
  798. zipping=False,
  799. frontend=True,
  800. backend=False,
  801. loglevel=reflex.constants.LogLevel.INFO,
  802. )
  803. self.frontend_thread = threading.Thread(target=self._run_frontend)
  804. self.frontend_thread.start()
  805. def _wait_frontend(self):
  806. self._poll_for(lambda: self.frontend_server is not None)
  807. if self.frontend_server is None or not self.frontend_server.socket.fileno():
  808. raise RuntimeError("Frontend did not start")
  809. def _start_backend(self):
  810. if self.app_instance is None:
  811. raise RuntimeError("App was not initialized.")
  812. os.environ[reflex.constants.SKIP_COMPILE_ENV_VAR] = "yes"
  813. self.backend = uvicorn.Server(
  814. uvicorn.Config(
  815. app=self.app_instance,
  816. host="127.0.0.1",
  817. port=0,
  818. workers=reflex.utils.processes.get_num_workers(),
  819. ),
  820. )
  821. self.backend.shutdown = self._get_backend_shutdown_handler()
  822. self.backend_thread = threading.Thread(target=self.backend.run)
  823. self.backend_thread.start()
  824. def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
  825. try:
  826. return super()._poll_for_servers(timeout)
  827. finally:
  828. os.environ.pop(reflex.constants.SKIP_COMPILE_ENV_VAR, None)
  829. def stop(self):
  830. """Stop the frontend python webserver."""
  831. super().stop()
  832. if self.frontend_server is not None:
  833. self.frontend_server.shutdown()
  834. if self.frontend_thread is not None:
  835. self.frontend_thread.join()