testing.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825
  1. """reflex.testing - tools for testing reflex apps."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import dataclasses
  6. import functools
  7. import inspect
  8. import os
  9. import pathlib
  10. import platform
  11. import re
  12. import signal
  13. import socket
  14. import socketserver
  15. import subprocess
  16. import textwrap
  17. import threading
  18. import time
  19. import types
  20. from http.server import SimpleHTTPRequestHandler
  21. from typing import (
  22. TYPE_CHECKING,
  23. Any,
  24. AsyncIterator,
  25. Callable,
  26. Coroutine,
  27. Optional,
  28. Type,
  29. TypeVar,
  30. Union,
  31. )
  32. import psutil
  33. import uvicorn
  34. import reflex
  35. import reflex.reflex
  36. import reflex.utils.build
  37. import reflex.utils.exec
  38. import reflex.utils.prerequisites
  39. import reflex.utils.processes
  40. from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis
  41. try:
  42. from selenium import webdriver # pyright: ignore [reportMissingImports]
  43. from selenium.webdriver.remote.webdriver import ( # pyright: ignore [reportMissingImports]
  44. WebDriver,
  45. )
  46. if TYPE_CHECKING:
  47. from selenium.webdriver.common.options import (
  48. ArgOptions, # pyright: ignore [reportMissingImports]
  49. )
  50. from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports]
  51. WebElement,
  52. )
  53. has_selenium = True
  54. except ImportError:
  55. has_selenium = False
  56. # The timeout (minutes) to check for the port.
  57. DEFAULT_TIMEOUT = 15
  58. POLL_INTERVAL = 0.25
  59. FRONTEND_POPEN_ARGS = {}
  60. T = TypeVar("T")
  61. TimeoutType = Optional[Union[int, float]]
  62. if platform.system == "Windows":
  63. FRONTEND_POPEN_ARGS["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore
  64. else:
  65. FRONTEND_POPEN_ARGS["start_new_session"] = True
# Snapshot of the State subclasses registered when this module is imported
# (the internal substates). AppHarness._initialize_app clears
# State.class_subclasses and restores exactly this set before loading each
# test app, so substates defined by earlier test apps do not leak into later ones.
INTERNAL_STATES = State.class_subclasses.copy()
  68. # borrowed from py3.11
  69. class chdir(contextlib.AbstractContextManager):
  70. """Non thread-safe context manager to change the current working directory."""
  71. def __init__(self, path):
  72. """Prepare contextmanager.
  73. Args:
  74. path: the path to change to
  75. """
  76. self.path = path
  77. self._old_cwd = []
  78. def __enter__(self):
  79. """Save current directory and perform chdir."""
  80. self._old_cwd.append(os.getcwd())
  81. os.chdir(self.path)
  82. def __exit__(self, *excinfo):
  83. """Change back to previous directory on stack.
  84. Args:
  85. excinfo: sys.exc_info captured in the context block
  86. """
  87. os.chdir(self._old_cwd.pop())
  88. @dataclasses.dataclass
  89. class AppHarness:
  90. """AppHarness executes a reflex app in-process for testing."""
  91. app_name: str
  92. app_source: Optional[types.FunctionType | types.ModuleType] | str
  93. app_path: pathlib.Path
  94. app_module_path: pathlib.Path
  95. app_module: Optional[types.ModuleType] = None
  96. app_instance: Optional[reflex.App] = None
  97. frontend_process: Optional[subprocess.Popen] = None
  98. frontend_url: Optional[str] = None
  99. frontend_output_thread: Optional[threading.Thread] = None
  100. backend_thread: Optional[threading.Thread] = None
  101. backend: Optional[uvicorn.Server] = None
  102. state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
  103. _frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
  104. @classmethod
  105. def create(
  106. cls,
  107. root: pathlib.Path,
  108. app_source: Optional[types.FunctionType | types.ModuleType | str] = None,
  109. app_name: Optional[str] = None,
  110. ) -> "AppHarness":
  111. """Create an AppHarness instance at root.
  112. Args:
  113. root: the directory that will contain the app under test.
  114. app_source: if specified, the source code from this function or module is used
  115. as the main module for the app. It may also be the raw source code text, as a str.
  116. If unspecified, then root must already contain a working reflex app and will be used directly.
  117. app_name: provide the name of the app, otherwise will be derived from app_source or root.
  118. Raises:
  119. ValueError: when app_source is a string and app_name is not provided.
  120. Returns:
  121. AppHarness instance
  122. """
  123. if app_name is None:
  124. if app_source is None:
  125. app_name = root.name.lower()
  126. elif isinstance(app_source, functools.partial):
  127. app_name = app_source.func.__name__.lower()
  128. elif isinstance(app_source, str):
  129. raise ValueError(
  130. "app_name must be provided when app_source is a string."
  131. )
  132. else:
  133. app_name = app_source.__name__.lower()
  134. return cls(
  135. app_name=app_name,
  136. app_source=app_source,
  137. app_path=root,
  138. app_module_path=root / app_name / f"{app_name}.py",
  139. )
  140. def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
  141. """Get the globals from a function or module object.
  142. Args:
  143. func: function or module object
  144. Returns:
  145. dict of globals
  146. """
  147. overrides = {}
  148. glbs = {}
  149. if not callable(func):
  150. return glbs
  151. if isinstance(func, functools.partial):
  152. overrides = func.keywords
  153. func = func.func
  154. for param in inspect.signature(func).parameters.values():
  155. if param.default is not inspect.Parameter.empty:
  156. glbs[param.name] = param.default
  157. glbs.update(overrides)
  158. return glbs
  159. def _get_source_from_app_source(self, app_source: Any) -> str:
  160. """Get the source from app_source.
  161. Args:
  162. app_source: function or module or str
  163. Returns:
  164. source code
  165. """
  166. if isinstance(app_source, str):
  167. return app_source
  168. source = inspect.getsource(app_source)
  169. source = re.sub(
  170. r"^\s*def\s+\w+\s*\(.*?\)(\s+->\s+\w+)?:", "", source, flags=re.DOTALL
  171. )
  172. return textwrap.dedent(source)
  173. def _initialize_app(self):
  174. os.environ["TELEMETRY_ENABLED"] = "" # disable telemetry reporting for tests
  175. self.app_path.mkdir(parents=True, exist_ok=True)
  176. if self.app_source is not None:
  177. app_globals = self._get_globals_from_signature(self.app_source)
  178. if isinstance(self.app_source, functools.partial):
  179. self.app_source = self.app_source.func # type: ignore
  180. # get the source from a function or module object
  181. source_code = "\n".join(
  182. [
  183. "\n".join(f"{k} = {v!r}" for k, v in app_globals.items()),
  184. self._get_source_from_app_source(self.app_source),
  185. ]
  186. )
  187. with chdir(self.app_path):
  188. reflex.reflex._init(
  189. name=self.app_name,
  190. template=reflex.constants.Templates.Kind.BLANK,
  191. loglevel=reflex.constants.LogLevel.INFO,
  192. )
  193. self.app_module_path.write_text(source_code)
  194. with chdir(self.app_path):
  195. # ensure config and app are reloaded when testing different app
  196. reflex.config.get_config(reload=True)
  197. # Clean out any `rx.page` decorators from other tests.
  198. reflex.app.DECORATED_PAGES.clear()
  199. # reset rx.State subclasses
  200. State.class_subclasses.clear()
  201. State.class_subclasses.update(INTERNAL_STATES)
  202. State._always_dirty_substates = set()
  203. State.get_class_substate.cache_clear()
  204. # Ensure the AppHarness test does not skip State assignment due to running via pytest
  205. os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
  206. self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True)
  207. self.app_instance = self.app_module.app
  208. if isinstance(self.app_instance._state_manager, StateManagerRedis):
  209. # Create our own redis connection for testing.
  210. self.state_manager = StateManagerRedis.create(self.app_instance.state)
  211. else:
  212. self.state_manager = self.app_instance._state_manager
  213. def _get_backend_shutdown_handler(self):
  214. if self.backend is None:
  215. raise RuntimeError("Backend was not initialized.")
  216. original_shutdown = self.backend.shutdown
  217. async def _shutdown_redis(*args, **kwargs) -> None:
  218. # ensure redis is closed before event loop
  219. try:
  220. if self.app_instance is not None and isinstance(
  221. self.app_instance.state_manager, StateManagerRedis
  222. ):
  223. await self.app_instance.state_manager.close()
  224. except ValueError:
  225. pass
  226. await original_shutdown(*args, **kwargs)
  227. return _shutdown_redis
  228. def _start_backend(self, port=0):
  229. if self.app_instance is None:
  230. raise RuntimeError("App was not initialized.")
  231. self.backend = uvicorn.Server(
  232. uvicorn.Config(
  233. app=self.app_instance.api,
  234. host="127.0.0.1",
  235. port=port,
  236. )
  237. )
  238. self.backend.shutdown = self._get_backend_shutdown_handler()
  239. self.backend_thread = threading.Thread(target=self.backend.run)
  240. self.backend_thread.start()
  241. def _start_frontend(self):
  242. # Set up the frontend.
  243. with chdir(self.app_path):
  244. config = reflex.config.get_config()
  245. config.api_url = "http://{0}:{1}".format(
  246. *self._poll_for_servers().getsockname(),
  247. )
  248. reflex.utils.build.setup_frontend(self.app_path)
  249. # Start the frontend.
  250. self.frontend_process = reflex.utils.processes.new_process(
  251. [reflex.utils.prerequisites.get_package_manager(), "run", "dev"],
  252. cwd=self.app_path / reflex.constants.Dirs.WEB,
  253. env={"PORT": "0"},
  254. **FRONTEND_POPEN_ARGS,
  255. )
  256. def _wait_frontend(self):
  257. while self.frontend_url is None:
  258. line = (
  259. self.frontend_process.stdout.readline() # pyright: ignore [reportOptionalMemberAccess]
  260. )
  261. if not line:
  262. break
  263. print(line) # for pytest diagnosis
  264. m = re.search(reflex.constants.Next.FRONTEND_LISTENING_REGEX, line)
  265. if m is not None:
  266. self.frontend_url = m.group(1)
  267. break
  268. if self.frontend_url is None:
  269. raise RuntimeError("Frontend did not start")
  270. def consume_frontend_output():
  271. while True:
  272. line = (
  273. self.frontend_process.stdout.readline() # pyright: ignore [reportOptionalMemberAccess]
  274. )
  275. if not line:
  276. break
  277. print(line)
  278. self.frontend_output_thread = threading.Thread(target=consume_frontend_output)
  279. self.frontend_output_thread.start()
  280. def start(self) -> "AppHarness":
  281. """Start the backend in a new thread and dev frontend as a separate process.
  282. Returns:
  283. self
  284. """
  285. self._initialize_app()
  286. self._start_backend()
  287. self._start_frontend()
  288. self._wait_frontend()
  289. return self
  290. def __enter__(self) -> "AppHarness":
  291. """Contextmanager protocol for `start()`.
  292. Returns:
  293. Instance of AppHarness after calling start()
  294. """
  295. return self.start()
  296. def stop(self) -> None:
  297. """Stop the frontend and backend servers."""
  298. if self.backend is not None:
  299. self.backend.should_exit = True
  300. if self.frontend_process is not None:
  301. # https://stackoverflow.com/a/70565806
  302. frontend_children = psutil.Process(self.frontend_process.pid).children(
  303. recursive=True,
  304. )
  305. if platform.system() == "Windows":
  306. self.frontend_process.terminate()
  307. else:
  308. pgrp = os.getpgid(self.frontend_process.pid)
  309. os.killpg(pgrp, signal.SIGTERM)
  310. # kill any remaining child processes
  311. for child in frontend_children:
  312. # It's okay if the process is already gone.
  313. with contextlib.suppress(psutil.NoSuchProcess):
  314. child.terminate()
  315. _, still_alive = psutil.wait_procs(frontend_children, timeout=3)
  316. for child in still_alive:
  317. # It's okay if the process is already gone.
  318. with contextlib.suppress(psutil.NoSuchProcess):
  319. child.kill()
  320. # wait for main process to exit
  321. self.frontend_process.communicate()
  322. if self.backend_thread is not None:
  323. self.backend_thread.join()
  324. if self.frontend_output_thread is not None:
  325. self.frontend_output_thread.join()
  326. for driver in self._frontends:
  327. driver.quit()
  328. def __exit__(self, *excinfo) -> None:
  329. """Contextmanager protocol for `stop()`.
  330. Args:
  331. excinfo: sys.exc_info captured in the context block
  332. """
  333. self.stop()
  334. @staticmethod
  335. def _poll_for(
  336. target: Callable[[], T],
  337. timeout: TimeoutType = None,
  338. step: TimeoutType = None,
  339. ) -> T | bool:
  340. """Generic polling logic.
  341. Args:
  342. target: callable that returns truthy if polling condition is met.
  343. timeout: max polling time
  344. step: interval between checking target()
  345. Returns:
  346. return value of target() if truthy within timeout
  347. False if timeout elapses
  348. """
  349. if timeout is None:
  350. timeout = DEFAULT_TIMEOUT
  351. if step is None:
  352. step = POLL_INTERVAL
  353. deadline = time.time() + timeout
  354. while time.time() < deadline:
  355. success = target()
  356. if success:
  357. return success
  358. time.sleep(step)
  359. return False
  360. @staticmethod
  361. async def _poll_for_async(
  362. target: Callable[[], Coroutine[None, None, T]],
  363. timeout: TimeoutType = None,
  364. step: TimeoutType = None,
  365. ) -> T | bool:
  366. """Generic polling logic for async functions.
  367. Args:
  368. target: callable that returns truthy if polling condition is met.
  369. timeout: max polling time
  370. step: interval between checking target()
  371. Returns:
  372. return value of target() if truthy within timeout
  373. False if timeout elapses
  374. """
  375. if timeout is None:
  376. timeout = DEFAULT_TIMEOUT
  377. if step is None:
  378. step = POLL_INTERVAL
  379. deadline = time.time() + timeout
  380. while time.time() < deadline:
  381. success = await target()
  382. if success:
  383. return success
  384. await asyncio.sleep(step)
  385. return False
  386. def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
  387. """Poll backend server for listening sockets.
  388. Args:
  389. timeout: how long to wait for listening socket.
  390. Returns:
  391. first active listening socket on the backend
  392. Raises:
  393. RuntimeError: when the backend hasn't started running
  394. TimeoutError: when server or sockets are not ready
  395. """
  396. if self.backend is None:
  397. raise RuntimeError("Backend is not running.")
  398. backend = self.backend
  399. # check for servers to be initialized
  400. if not self._poll_for(
  401. target=lambda: getattr(backend, "servers", False),
  402. timeout=timeout,
  403. ):
  404. raise TimeoutError("Backend servers are not initialized.")
  405. # check for sockets to be listening
  406. if not self._poll_for(
  407. target=lambda: getattr(backend.servers[0], "sockets", False),
  408. timeout=timeout,
  409. ):
  410. raise TimeoutError("Backend is not listening.")
  411. return backend.servers[0].sockets[0]
  412. def frontend(self, driver_clz: Optional[Type["WebDriver"]] = None) -> "WebDriver":
  413. """Get a selenium webdriver instance pointed at the app.
  414. Args:
  415. driver_clz: webdriver.Chrome (default), webdriver.Firefox, webdriver.Safari,
  416. webdriver.Edge, etc
  417. Returns:
  418. Instance of the given webdriver navigated to the frontend url of the app.
  419. Raises:
  420. RuntimeError: when selenium is not importable or frontend is not running
  421. """
  422. if not has_selenium:
  423. raise RuntimeError(
  424. "Frontend functionality requires `selenium` to be installed, "
  425. "and it could not be imported."
  426. )
  427. if self.frontend_url is None:
  428. raise RuntimeError("Frontend is not running.")
  429. want_headless = False
  430. options: ArgOptions | None = None
  431. if os.environ.get("APP_HARNESS_HEADLESS"):
  432. want_headless = True
  433. if driver_clz is None:
  434. requested_driver = os.environ.get("APP_HARNESS_DRIVER", "Chrome")
  435. driver_clz = getattr(webdriver, requested_driver)
  436. options = webdriver.ChromeOptions()
  437. if driver_clz is webdriver.Chrome and want_headless:
  438. options = webdriver.ChromeOptions()
  439. options.add_argument("--headless=new")
  440. elif driver_clz is webdriver.Firefox and want_headless:
  441. options = webdriver.FirefoxOptions()
  442. options.add_argument("-headless")
  443. elif driver_clz is webdriver.Edge and want_headless:
  444. options = webdriver.EdgeOptions()
  445. options.add_argument("headless")
  446. if options and (args := os.environ.get("APP_HARNESS_DRIVER_ARGS")):
  447. for arg in args.split(","):
  448. options.add_argument(arg)
  449. driver = driver_clz(options=options) # type: ignore
  450. driver.get(self.frontend_url)
  451. self._frontends.append(driver)
  452. return driver
  453. async def get_state(self, token: str) -> BaseState:
  454. """Get the state associated with the given token.
  455. Args:
  456. token: The state token to look up.
  457. Returns:
  458. The state instance associated with the given token
  459. Raises:
  460. RuntimeError: when the app hasn't started running
  461. """
  462. if self.state_manager is None:
  463. raise RuntimeError("state_manager is not set.")
  464. try:
  465. return await self.state_manager.get_state(token)
  466. finally:
  467. if isinstance(self.state_manager, StateManagerRedis):
  468. await self.state_manager.close()
  469. async def set_state(self, token: str, **kwargs) -> None:
  470. """Set the state associated with the given token.
  471. Args:
  472. token: The state token to set.
  473. kwargs: Attributes to set on the state.
  474. Raises:
  475. RuntimeError: when the app hasn't started running
  476. """
  477. if self.state_manager is None:
  478. raise RuntimeError("state_manager is not set.")
  479. state = await self.get_state(token)
  480. for key, value in kwargs.items():
  481. setattr(state, key, value)
  482. try:
  483. await self.state_manager.set_state(token, state)
  484. finally:
  485. if isinstance(self.state_manager, StateManagerRedis):
  486. await self.state_manager.close()
  487. @contextlib.asynccontextmanager
  488. async def modify_state(self, token: str) -> AsyncIterator[State]:
  489. """Modify the state associated with the given token and send update to frontend.
  490. Args:
  491. token: The state token to modify
  492. Yields:
  493. The state instance associated with the given token
  494. Raises:
  495. RuntimeError: when the app hasn't started running
  496. """
  497. if self.state_manager is None:
  498. raise RuntimeError("state_manager is not set.")
  499. if self.app_instance is None:
  500. raise RuntimeError("App is not running.")
  501. app_state_manager = self.app_instance.state_manager
  502. if isinstance(self.state_manager, StateManagerRedis):
  503. # Temporarily replace the app's state manager with our own, since
  504. # the redis connection is on the backend_thread event loop
  505. self.app_instance._state_manager = self.state_manager
  506. try:
  507. async with self.app_instance.modify_state(token) as state:
  508. yield state
  509. finally:
  510. if isinstance(self.state_manager, StateManagerRedis):
  511. self.app_instance._state_manager = app_state_manager
  512. await self.state_manager.close()
  513. def poll_for_content(
  514. self,
  515. element: "WebElement",
  516. timeout: TimeoutType = None,
  517. exp_not_equal: str = "",
  518. ) -> str:
  519. """Poll element.text for change.
  520. Args:
  521. element: selenium webdriver element to check
  522. timeout: how long to poll element.text
  523. exp_not_equal: exit the polling loop when the element text does not match
  524. Returns:
  525. The element text when the polling loop exited
  526. Raises:
  527. TimeoutError: when the timeout expires before text changes
  528. """
  529. if not self._poll_for(
  530. target=lambda: element.text != exp_not_equal,
  531. timeout=timeout,
  532. ):
  533. raise TimeoutError(
  534. f"{element} content remains {exp_not_equal!r} while polling.",
  535. )
  536. return element.text
  537. def poll_for_value(
  538. self,
  539. element: "WebElement",
  540. timeout: TimeoutType = None,
  541. exp_not_equal: str = "",
  542. ) -> Optional[str]:
  543. """Poll element.get_attribute("value") for change.
  544. Args:
  545. element: selenium webdriver element to check
  546. timeout: how long to poll element value attribute
  547. exp_not_equal: exit the polling loop when the value does not match
  548. Returns:
  549. The element value when the polling loop exited
  550. Raises:
  551. TimeoutError: when the timeout expires before value changes
  552. """
  553. if not self._poll_for(
  554. target=lambda: element.get_attribute("value") != exp_not_equal,
  555. timeout=timeout,
  556. ):
  557. raise TimeoutError(
  558. f"{element} content remains {exp_not_equal!r} while polling.",
  559. )
  560. return element.get_attribute("value")
  561. def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
  562. """Poll app state_manager for any connected clients.
  563. Args:
  564. timeout: how long to wait for client states
  565. Returns:
  566. active state instances when the polling loop exited
  567. Raises:
  568. RuntimeError: when the app hasn't started running
  569. TimeoutError: when the timeout expires before any states are seen
  570. """
  571. if self.app_instance is None:
  572. raise RuntimeError("App is not running.")
  573. state_manager = self.app_instance.state_manager
  574. assert isinstance(
  575. state_manager, StateManagerMemory
  576. ), "Only works with memory state manager"
  577. if not self._poll_for(
  578. target=lambda: state_manager.states,
  579. timeout=timeout,
  580. ):
  581. raise TimeoutError("No states were observed while polling.")
  582. return state_manager.states
  583. class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
  584. """SimpleHTTPRequestHandler with custom error page handling."""
  585. def __init__(self, *args, error_page_map: dict[int, pathlib.Path], **kwargs):
  586. """Initialize the handler.
  587. Args:
  588. error_page_map: map of error code to error page path
  589. *args: passed through to superclass
  590. **kwargs: passed through to superclass
  591. """
  592. self.error_page_map = error_page_map
  593. super().__init__(*args, **kwargs)
  594. def send_error(
  595. self, code: int, message: str | None = None, explain: str | None = None
  596. ) -> None:
  597. """Send the error page for the given error code.
  598. If the code matches a custom error page, then message and explain are
  599. ignored.
  600. Args:
  601. code: the error code
  602. message: the error message
  603. explain: the error explanation
  604. """
  605. error_page = self.error_page_map.get(code)
  606. if error_page:
  607. self.send_response(code, message)
  608. self.send_header("Connection", "close")
  609. body = error_page.read_bytes()
  610. self.send_header("Content-Type", self.error_content_type)
  611. self.send_header("Content-Length", str(len(body)))
  612. self.end_headers()
  613. self.wfile.write(body)
  614. else:
  615. super().send_error(code, message, explain)
  616. class Subdir404TCPServer(socketserver.TCPServer):
  617. """TCPServer for SimpleHTTPRequestHandlerCustomErrors that serves from a subdir."""
  618. def __init__(
  619. self,
  620. *args,
  621. root: pathlib.Path,
  622. error_page_map: dict[int, pathlib.Path] | None,
  623. **kwargs,
  624. ):
  625. """Initialize the server.
  626. Args:
  627. root: the root directory to serve from
  628. error_page_map: map of error code to error page path
  629. *args: passed through to superclass
  630. **kwargs: passed through to superclass
  631. """
  632. self.root = root
  633. self.error_page_map = error_page_map or {}
  634. super().__init__(*args, **kwargs)
  635. def finish_request(self, request: socket.socket, client_address: tuple[str, int]):
  636. """Finish one request by instantiating RequestHandlerClass.
  637. Args:
  638. request: the requesting socket
  639. client_address: (host, port) referring to the client’s address.
  640. """
  641. self.RequestHandlerClass(
  642. request,
  643. client_address,
  644. self,
  645. directory=str(self.root), # type: ignore
  646. error_page_map=self.error_page_map, # type: ignore
  647. )
  648. class AppHarnessProd(AppHarness):
  649. """AppHarnessProd executes a reflex app in-process for testing.
  650. In prod mode, instead of running `next dev` the app is exported as static
  651. files and served via the builtin python http.server with custom 404 redirect
  652. handling. Additionally, the backend runs in multi-worker mode.
  653. """
  654. frontend_thread: Optional[threading.Thread] = None
  655. frontend_server: Optional[Subdir404TCPServer] = None
  656. def _run_frontend(self):
  657. web_root = self.app_path / reflex.constants.Dirs.WEB_STATIC
  658. error_page_map = {
  659. 404: web_root / "404.html",
  660. }
  661. with Subdir404TCPServer(
  662. ("", 0),
  663. SimpleHTTPRequestHandlerCustomErrors,
  664. root=web_root,
  665. error_page_map=error_page_map,
  666. ) as self.frontend_server:
  667. self.frontend_url = "http://localhost:{1}".format(
  668. *self.frontend_server.socket.getsockname()
  669. )
  670. self.frontend_server.serve_forever()
  671. def _start_frontend(self):
  672. # Set up the frontend.
  673. with chdir(self.app_path):
  674. config = reflex.config.get_config()
  675. config.api_url = "http://{0}:{1}".format(
  676. *self._poll_for_servers().getsockname(),
  677. )
  678. reflex.reflex.export(
  679. zipping=False,
  680. frontend=True,
  681. backend=False,
  682. loglevel=reflex.constants.LogLevel.INFO,
  683. )
  684. self.frontend_thread = threading.Thread(target=self._run_frontend)
  685. self.frontend_thread.start()
  686. def _wait_frontend(self):
  687. self._poll_for(lambda: self.frontend_server is not None)
  688. if self.frontend_server is None or not self.frontend_server.socket.fileno():
  689. raise RuntimeError("Frontend did not start")
  690. def _start_backend(self):
  691. if self.app_instance is None:
  692. raise RuntimeError("App was not initialized.")
  693. os.environ[reflex.constants.SKIP_COMPILE_ENV_VAR] = "yes"
  694. self.backend = uvicorn.Server(
  695. uvicorn.Config(
  696. app=self.app_instance,
  697. host="127.0.0.1",
  698. port=0,
  699. workers=reflex.utils.processes.get_num_workers(),
  700. ),
  701. )
  702. self.backend.shutdown = self._get_backend_shutdown_handler()
  703. self.backend_thread = threading.Thread(target=self.backend.run)
  704. self.backend_thread.start()
  705. def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
  706. try:
  707. return super()._poll_for_servers(timeout)
  708. finally:
  709. os.environ.pop(reflex.constants.SKIP_COMPILE_ENV_VAR, None)
  710. def stop(self):
  711. """Stop the frontend python webserver."""
  712. super().stop()
  713. if self.frontend_server is not None:
  714. self.frontend_server.shutdown()
  715. if self.frontend_thread is not None:
  716. self.frontend_thread.join()