12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040 |
- """reflex.testing - tools for testing reflex apps."""
- from __future__ import annotations
- import asyncio
- import contextlib
- import dataclasses
- import functools
- import inspect
- import os
- import platform
- import re
- import signal
- import socket
- import socketserver
- import subprocess
- import textwrap
- import threading
- import time
- import types
- from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
- from http.server import SimpleHTTPRequestHandler
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, TypeVar
- import psutil
- import uvicorn
- import reflex
- import reflex.reflex
- import reflex.utils.build
- import reflex.utils.exec
- import reflex.utils.format
- import reflex.utils.prerequisites
- import reflex.utils.processes
- from reflex.components.component import CustomComponent
- from reflex.config import environment, get_config
- from reflex.state import (
- BaseState,
- StateManager,
- StateManagerDisk,
- StateManagerMemory,
- StateManagerRedis,
- reload_state_module,
- )
- from reflex.utils import console
- from reflex.utils.export import export
- from reflex.utils.types import ASGIApp
- try:
- from selenium import webdriver
- from selenium.webdriver.remote.webdriver import WebDriver
- if TYPE_CHECKING:
- from selenium.webdriver.common.options import ArgOptions
- from selenium.webdriver.remote.webelement import WebElement
- has_selenium = True
- except ImportError:
- has_selenium = False
- # The timeout (minutes) to check for the port.
- DEFAULT_TIMEOUT = 15
- POLL_INTERVAL = 0.25
- FRONTEND_POPEN_ARGS = {}
- T = TypeVar("T")
- TimeoutType = int | float | None
- if platform.system() == "Windows":
- FRONTEND_POPEN_ARGS["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP # pyright: ignore [reportAttributeAccessIssue]
- FRONTEND_POPEN_ARGS["shell"] = True
- else:
- FRONTEND_POPEN_ARGS["start_new_session"] = True
- # borrowed from py3.11
- class chdir(contextlib.AbstractContextManager): # noqa: N801
- """Non thread-safe context manager to change the current working directory."""
- def __init__(self, path: str | Path):
- """Prepare contextmanager.
- Args:
- path: the path to change to
- """
- self.path = path
- self._old_cwd = []
- def __enter__(self):
- """Save current directory and perform chdir."""
- self._old_cwd.append(Path.cwd())
- os.chdir(self.path)
- def __exit__(self, *excinfo):
- """Change back to previous directory on stack.
- Args:
- excinfo: sys.exc_info captured in the context block
- """
- os.chdir(self._old_cwd.pop())
@dataclasses.dataclass
class AppHarness:
    """AppHarness executes a reflex app in-process for testing."""

    # Name of the app; also names the app package dir and module file.
    app_name: str
    # Source of the app module: a function (its body is used), a module, raw
    # source text, a functools.partial (its keywords become module globals), or
    # None to use an app that already exists at app_path.
    app_source: (
        Callable[[], None] | types.ModuleType | str | functools.partial[Any] | None
    )
    # Directory that contains the app under test.
    app_path: Path
    # Generated module path: app_path / app_name / f"{app_name}.py".
    app_module_path: Path
    app_module: types.ModuleType | None = None
    app_instance: reflex.App | None = None
    # ASGI callable obtained by calling app_instance; served by uvicorn.
    app_asgi: ASGIApp | None = None
    frontend_process: subprocess.Popen | None = None
    frontend_url: str | None = None
    frontend_output_thread: threading.Thread | None = None
    backend_thread: threading.Thread | None = None
    backend: uvicorn.Server | None = None
    # State manager used by get_state/set_state/modify_state.
    state_manager: StateManager | None = None
    # Webdrivers created by frontend(); they are quit in stop().
    _frontends: list[WebDriver] = dataclasses.field(default_factory=list)

    @classmethod
    def create(
        cls,
        root: Path,
        app_source: (
            Callable[[], None] | types.ModuleType | str | functools.partial[Any] | None
        ) = None,
        app_name: str | None = None,
    ) -> AppHarness:
        """Create an AppHarness instance at root.

        Args:
            root: the directory that will contain the app under test.
            app_source: if specified, the source code from this function or module is used
                as the main module for the app. It may also be the raw source code text, as a str.
                If unspecified, then root must already contain a working reflex app and will be used directly.
            app_name: provide the name of the app, otherwise will be derived from app_source or root.

        Raises:
            ValueError: when app_source is a string and app_name is not provided.

        Returns:
            AppHarness instance
        """
        if app_name is None:
            if app_source is None:
                app_name = root.name
            elif isinstance(app_source, functools.partial):
                keywords = app_source.keywords
                slug_suffix = "_".join([str(v) for v in keywords.values()])
                func_name = app_source.func.__name__
                app_name = f"{func_name}_{slug_suffix}"
                # Keyword values may contain characters invalid in identifiers.
                app_name = re.sub(r"[^a-zA-Z0-9_]", "_", app_name)
            elif isinstance(app_source, str):
                raise ValueError(
                    "app_name must be provided when app_source is a string."
                )
            else:
                app_name = app_source.__name__

            app_name = app_name.lower()
            # Collapse runs of underscores into a single one.
            while "__" in app_name:
                app_name = app_name.replace("__", "_")
        return cls(
            app_name=app_name,
            app_source=app_source,
            app_path=root,
            app_module_path=root / app_name / f"{app_name}.py",
        )

    def get_state_name(self, state_cls_name: str) -> str:
        """Get the state name for the given state class name.

        Args:
            state_cls_name: The state class name

        Returns:
            The state name
        """
        return reflex.utils.format.to_snake_case(
            f"{self.app_name}___{self.app_name}___" + state_cls_name
        )

    def get_full_state_name(self, path: list[str]) -> str:
        """Get the full state name for the given state class name.

        Args:
            path: A list of state class names

        Returns:
            The full state name
        """
        # NOTE: using State.get_name() somehow causes trouble here
        # path = [State.get_name()] + [self.get_state_name(p) for p in path] # noqa: ERA001
        path = ["reflex___state____state"] + [self.get_state_name(p) for p in path]
        return ".".join(path)

    def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
        """Get the globals from a function or module object.

        Parameters with default values become globals; keywords bound by a
        functools.partial override those defaults.

        Args:
            func: function or module object

        Returns:
            dict of globals
        """
        overrides = {}
        glbs = {}
        if not callable(func):
            return glbs
        if isinstance(func, functools.partial):
            overrides = func.keywords
            func = func.func
        for param in inspect.signature(func).parameters.values():
            if param.default is not inspect.Parameter.empty:
                glbs[param.name] = param.default
        glbs.update(overrides)
        return glbs

    def _get_source_from_app_source(self, app_source: Any) -> str:
        """Get the source from app_source.

        For a function, the leading ``def ...:`` header is stripped and the
        body dedented so it can run as module-level code.

        Args:
            app_source: function or module or str

        Returns:
            source code
        """
        if isinstance(app_source, str):
            return app_source
        source = inspect.getsource(app_source)
        source = re.sub(
            r"^\s*def\s+\w+\s*\(.*?\)(\s+->\s+\w+)?:", "", source, flags=re.DOTALL
        )
        return textwrap.dedent(source)

    def _initialize_app(self):
        """Write the app module (when app_source is given), then import and compile the app.

        Raises:
            RuntimeError: when the app uses StateManagerRedis but has no State.
        """
        # disable telemetry reporting for tests
        os.environ["TELEMETRY_ENABLED"] = "false"
        CustomComponent.create().get_component.cache_clear()
        self.app_path.mkdir(parents=True, exist_ok=True)
        if self.app_source is not None:
            app_globals = self._get_globals_from_signature(self.app_source)
            if isinstance(self.app_source, functools.partial):
                self.app_source = self.app_source.func
            # get the source from a function or module object
            source_code = "\n".join(
                [
                    "\n".join(
                        self.get_app_global_source(k, v) for k, v in app_globals.items()
                    ),
                    self._get_source_from_app_source(self.app_source),
                ]
            )
            get_config().loglevel = reflex.constants.LogLevel.INFO
            with chdir(self.app_path):
                reflex.reflex._init(
                    name=self.app_name,
                    template=reflex.constants.Templates.DEFAULT,
                )
                self.app_module_path.write_text(source_code)
        else:
            # Just initialize the web folder.
            with chdir(self.app_path):
                reflex.utils.prerequisites.initialize_frontend_dependencies()
        with chdir(self.app_path):
            # ensure config and app are reloaded when testing different app
            reflex.config.get_config(reload=True)
            # Ensure the AppHarness test does not skip State assignment due to running via pytest
            os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
            os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
            # Ensure we actually compile the app during first initialization.
            self.app_instance, self.app_module = (
                reflex.utils.prerequisites.get_and_validate_app(
                    # Do not reload the module for pre-existing apps (only apps generated from source)
                    reload=self.app_source is not None
                )
            )
            self.app_asgi = self.app_instance()
            if self.app_instance and isinstance(
                self.app_instance._state_manager, StateManagerRedis
            ):
                if self.app_instance._state is None:
                    raise RuntimeError("State is not set.")
                # Create our own redis connection for testing.
                self.state_manager = StateManagerRedis.create(self.app_instance._state)
            else:
                self.state_manager = (
                    self.app_instance._state_manager if self.app_instance else None
                )

    def _reload_state_module(self):
        """Reload the rx.State module to avoid conflict when reloading."""
        reload_state_module(module=f"{self.app_name}.{self.app_name}")

    def _get_backend_shutdown_handler(self):
        """Build a replacement for ``backend.shutdown`` that releases app resources first.

        Returns:
            An async shutdown function that closes a redis state manager,
            shuts down socketio, disposes the async sqlalchemy engine (if any)
            and finally awaits the original uvicorn shutdown.

        Raises:
            RuntimeError: when the backend has not been created.
        """
        if self.backend is None:
            raise RuntimeError("Backend was not initialized.")

        original_shutdown = self.backend.shutdown

        async def _shutdown(*args, **kwargs) -> None:
            # ensure redis is closed before event loop
            if self.app_instance is not None and isinstance(
                self.app_instance._state_manager, StateManagerRedis
            ):
                with contextlib.suppress(ValueError):
                    await self.app_instance._state_manager.close()

            # socketio shutdown handler
            if self.app_instance is not None and self.app_instance.sio is not None:
                with contextlib.suppress(TypeError):
                    await self.app_instance.sio.shutdown()

            # sqlalchemy async engine shutdown handler
            try:
                async_engine = reflex.model.get_async_engine(None)
            except ValueError:
                pass
            else:
                await async_engine.dispose()

            await original_shutdown(*args, **kwargs)

        return _shutdown

    def _start_backend(self, port: int = 0):
        """Run uvicorn for ``app_asgi`` in a background thread.

        Args:
            port: port to bind on 127.0.0.1; 0 lets the OS pick a free port.

        Raises:
            RuntimeError: when the app has not been initialized.
        """
        if self.app_asgi is None:
            raise RuntimeError("App was not initialized.")
        self.backend = uvicorn.Server(
            uvicorn.Config(
                app=self.app_asgi,
                host="127.0.0.1",
                port=port,
            )
        )
        self.backend.shutdown = self._get_backend_shutdown_handler()
        with chdir(self.app_path):
            print(  # noqa: T201
                "Creating backend in a new thread..."
            )  # for pytest diagnosis
            self.backend_thread = threading.Thread(target=self.backend.run)
            self.backend_thread.start()
            print("Backend started.")  # for pytest diagnosis #noqa: T201

    async def _reset_backend_state_manager(self):
        """Reset the StateManagerRedis event loop affinity.

        This is necessary when the backend is restarted and the state manager is a
        StateManagerRedis instance.

        Raises:
            RuntimeError: when the state manager cannot be reset
        """
        if (
            self.app_instance is not None
            and isinstance(
                self.app_instance._state_manager,
                StateManagerRedis,
            )
            and self.app_instance._state is not None
        ):
            with contextlib.suppress(RuntimeError):
                await self.app_instance._state_manager.close()
            self.app_instance._state_manager = StateManagerRedis.create(
                state=self.app_instance._state,
            )
            if not isinstance(self.app_instance.state_manager, StateManagerRedis):
                raise RuntimeError("Failed to reset state manager.")

    def _start_frontend(self):
        """Point the frontend at the running backend, build it, and launch the dev server."""
        # Set up the frontend.
        with chdir(self.app_path):
            config = reflex.config.get_config()
            print("Polling for servers...")  # for pytest diagnosis #noqa: T201
            config.api_url = "http://{}:{}".format(
                *self._poll_for_servers(timeout=30).getsockname(),
            )
            print("Building frontend...")  # for pytest diagnosis #noqa: T201
            reflex.utils.build.setup_frontend(self.app_path)

        print("Frontend starting...")  # for pytest diagnosis #noqa: T201

        # Start the frontend.
        self.frontend_process = reflex.utils.processes.new_process(
            [
                *reflex.utils.prerequisites.get_js_package_executor(raise_on_none=True)[
                    0
                ],
                "run",
                "dev",
            ],
            cwd=self.app_path / reflex.utils.prerequisites.get_web_dir(),
            env={"PORT": "0", "NO_COLOR": "1"},
            **FRONTEND_POPEN_ARGS,
        )

    def _wait_frontend(self):
        """Wait for the dev server to print its URL, then drain its output in a thread.

        Raises:
            RuntimeError: when the process has no stdout or its output ends
                before a listening URL is printed.
        """
        if self.frontend_process is None or self.frontend_process.stdout is None:
            raise RuntimeError("Frontend process has no stdout.")
        while self.frontend_url is None:
            line = self.frontend_process.stdout.readline()
            if not line:
                break
            print(line)  # for pytest diagnosis #noqa: T201
            m = re.search(reflex.constants.ReactRouter.FRONTEND_LISTENING_REGEX, line)
            if m is not None:
                self.frontend_url = m.group(1)
                config = reflex.config.get_config()
                config.deploy_url = self.frontend_url
                break
        if self.frontend_url is None:
            raise RuntimeError("Frontend did not start")

        def consume_frontend_output():
            # Keep reading so the child process never blocks on a full pipe.
            while True:
                try:
                    line = (
                        self.frontend_process.stdout.readline()  # pyright: ignore [reportOptionalMemberAccess]
                    )
                # catch I/O operation on closed file.
                except ValueError as e:
                    console.error(str(e))
                    break
                if not line:
                    break

        self.frontend_output_thread = threading.Thread(target=consume_frontend_output)
        self.frontend_output_thread.start()

    def start(self) -> AppHarness:
        """Start the backend in a new thread and dev frontend as a separate process.

        Returns:
            self
        """
        self._initialize_app()
        self._start_backend()
        self._start_frontend()
        self._wait_frontend()
        return self

    @staticmethod
    def get_app_global_source(key: str, value: Any):
        """Get the source code of a global object.

        If value is a function or class we render the actual
        source of value otherwise we assign value to key.

        Args:
            key: variable name to assign value to.
            value: value of the global variable.

        Returns:
            The rendered app global code.
        """
        if not inspect.isclass(value) and not inspect.isfunction(value):
            return f"{key} = {value!r}"
        return inspect.getsource(value)

    def __enter__(self) -> AppHarness:
        """Contextmanager protocol for `start()`.

        Returns:
            Instance of AppHarness after calling start()
        """
        return self.start()

    def stop(self) -> None:
        """Stop the frontend and backend servers."""
        # Quit browsers first to avoid any lingering events being sent during shutdown.
        for driver in self._frontends:
            driver.quit()

        self._reload_state_module()

        if self.backend is not None:
            self.backend.should_exit = True
        if self.frontend_process is not None:
            # https://stackoverflow.com/a/70565806
            frontend_children = psutil.Process(self.frontend_process.pid).children(
                recursive=True,
            )
            if platform.system() == "Windows":
                self.frontend_process.terminate()
            else:
                # The process was started in its own session/group
                # (FRONTEND_POPEN_ARGS), so signal the whole group.
                pgrp = os.getpgid(self.frontend_process.pid)
                os.killpg(pgrp, signal.SIGTERM)
            # kill any remaining child processes
            for child in frontend_children:
                # It's okay if the process is already gone.
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.terminate()
            _, still_alive = psutil.wait_procs(frontend_children, timeout=3)
            for child in still_alive:
                # It's okay if the process is already gone.
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.kill()
            # wait for main process to exit
            self.frontend_process.communicate()
        if self.backend_thread is not None:
            self.backend_thread.join()
        if self.frontend_output_thread is not None:
            self.frontend_output_thread.join()

    def __exit__(self, *excinfo) -> None:
        """Contextmanager protocol for `stop()`.

        Args:
            excinfo: sys.exc_info captured in the context block
        """
        self.stop()

    @staticmethod
    def _poll_for(
        target: Callable[[], T],
        timeout: TimeoutType = None,
        step: TimeoutType = None,
    ) -> T | bool:
        """Generic polling logic.

        Args:
            target: callable that returns truthy if polling condition is met.
            timeout: max polling time (seconds); defaults to DEFAULT_TIMEOUT
            step: interval between checking target() (seconds); defaults to POLL_INTERVAL

        Returns:
            return value of target() if truthy within timeout
            False if timeout elapses
        """
        if timeout is None:
            timeout = DEFAULT_TIMEOUT
        if step is None:
            step = POLL_INTERVAL
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            success = target()
            if success:
                return success
            time.sleep(step)
        return False

    @staticmethod
    async def _poll_for_async(
        target: Callable[[], Coroutine[None, None, T]],
        timeout: TimeoutType = None,
        step: TimeoutType = None,
    ) -> T | bool:
        """Generic polling logic for async functions.

        Args:
            target: callable that returns truthy if polling condition is met.
            timeout: max polling time (seconds); defaults to DEFAULT_TIMEOUT
            step: interval between checking target() (seconds); defaults to POLL_INTERVAL

        Returns:
            return value of target() if truthy within timeout
            False if timeout elapses
        """
        if timeout is None:
            timeout = DEFAULT_TIMEOUT
        if step is None:
            step = POLL_INTERVAL
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            success = await target()
            if success:
                return success
            await asyncio.sleep(step)
        return False

    def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
        """Poll backend server for listening sockets.

        Args:
            timeout: how long to wait for listening socket.

        Returns:
            first active listening socket on the backend

        Raises:
            RuntimeError: when the backend hasn't started running
            TimeoutError: when server or sockets are not ready
        """
        if self.backend is None:
            raise RuntimeError("Backend is not running.")
        backend = self.backend
        # check for servers to be initialized
        if not self._poll_for(
            target=lambda: getattr(backend, "servers", False),
            timeout=timeout,
        ):
            raise TimeoutError("Backend servers are not initialized.")
        # check for sockets to be listening
        if not self._poll_for(
            target=lambda: getattr(backend.servers[0], "sockets", False),
            timeout=timeout,
        ):
            raise TimeoutError("Backend is not listening.")
        return backend.servers[0].sockets[0]

    def frontend(
        self,
        driver_clz: type[WebDriver] | None = None,
        driver_kwargs: dict[str, Any] | None = None,
        driver_options: ArgOptions | None = None,
        driver_option_args: list[str] | None = None,
        driver_option_capabilities: dict[str, Any] | None = None,
    ) -> WebDriver:
        """Get a selenium webdriver instance pointed at the app.

        Args:
            driver_clz: webdriver.Chrome (default), webdriver.Firefox, webdriver.Safari,
                webdriver.Edge, etc
            driver_kwargs: additional keyword arguments to pass to the webdriver constructor
            driver_options: selenium ArgOptions instance to pass to the webdriver constructor
            driver_option_args: additional arguments for the webdriver options
            driver_option_capabilities: additional capabilities for the webdriver options

        Returns:
            Instance of the given webdriver navigated to the frontend url of the app.

        Raises:
            RuntimeError: when selenium is not importable or frontend is not running
        """
        if not has_selenium:
            raise RuntimeError(
                "Frontend functionality requires `selenium` to be installed, "
                "and it could not be imported."
            )
        if self.frontend_url is None:
            raise RuntimeError("Frontend is not running.")
        want_headless = bool(environment.APP_HARNESS_HEADLESS.get())
        if driver_clz is None:
            requested_driver = environment.APP_HARNESS_DRIVER.get()
            driver_clz = getattr(webdriver, requested_driver)  # pyright: ignore [reportPossiblyUnboundVariable]
            if driver_options is None:
                driver_options = getattr(webdriver, f"{requested_driver}Options")()  # pyright: ignore [reportPossiblyUnboundVariable]
        if driver_clz is webdriver.Chrome:  # pyright: ignore [reportPossiblyUnboundVariable]
            if driver_options is None:
                driver_options = webdriver.ChromeOptions()  # pyright: ignore [reportPossiblyUnboundVariable]
            driver_options.add_argument("--class=AppHarness")
            if want_headless:
                driver_options.add_argument("--headless=new")
        elif driver_clz is webdriver.Firefox:  # pyright: ignore [reportPossiblyUnboundVariable]
            if driver_options is None:
                driver_options = webdriver.FirefoxOptions()  # pyright: ignore [reportPossiblyUnboundVariable]
            if want_headless:
                driver_options.add_argument("-headless")
        elif driver_clz is webdriver.Edge:  # pyright: ignore [reportPossiblyUnboundVariable]
            if driver_options is None:
                driver_options = webdriver.EdgeOptions()  # pyright: ignore [reportPossiblyUnboundVariable]
            if want_headless:
                driver_options.add_argument("headless")
        if driver_options is None:
            raise RuntimeError(f"Could not determine options for {driver_clz}")
        if args := environment.APP_HARNESS_DRIVER_ARGS.get():
            for arg in args.split(","):
                stripped = arg.strip()
                # Ignore blanks left by stray/trailing commas or padding spaces.
                if stripped:
                    driver_options.add_argument(stripped)
        if driver_option_args is not None:
            for arg in driver_option_args:
                driver_options.add_argument(arg)
        if driver_option_capabilities is not None:
            for key, value in driver_option_capabilities.items():
                driver_options.set_capability(key, value)
        if driver_kwargs is None:
            driver_kwargs = {}
        driver = driver_clz(options=driver_options, **driver_kwargs)  # pyright: ignore [reportOptionalCall, reportArgumentType]
        driver.get(self.frontend_url)
        self._frontends.append(driver)
        return driver

    async def get_state(self, token: str) -> BaseState:
        """Get the state associated with the given token.

        Args:
            token: The state token to look up.

        Returns:
            The state instance associated with the given token

        Raises:
            RuntimeError: when the app hasn't started running
        """
        if self.state_manager is None:
            raise RuntimeError("state_manager is not set.")
        try:
            return await self.state_manager.get_state(token)
        finally:
            if isinstance(self.state_manager, StateManagerRedis):
                await self.state_manager.close()

    async def set_state(self, token: str, **kwargs) -> None:
        """Set the state associated with the given token.

        Args:
            token: The state token to set.
            kwargs: Attributes to set on the state.

        Raises:
            RuntimeError: when the app hasn't started running
        """
        if self.state_manager is None:
            raise RuntimeError("state_manager is not set.")
        state = await self.get_state(token)
        for key, value in kwargs.items():
            setattr(state, key, value)
        try:
            await self.state_manager.set_state(token, state)
        finally:
            if isinstance(self.state_manager, StateManagerRedis):
                await self.state_manager.close()

    @contextlib.asynccontextmanager
    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
        """Modify the state associated with the given token and send update to frontend.

        Args:
            token: The state token to modify

        Yields:
            The state instance associated with the given token

        Raises:
            RuntimeError: when the app hasn't started running
        """
        if self.state_manager is None:
            raise RuntimeError("state_manager is not set.")
        if self.app_instance is None:
            raise RuntimeError("App is not running.")
        app_state_manager = self.app_instance.state_manager
        if isinstance(self.state_manager, StateManagerRedis):
            # Temporarily replace the app's state manager with our own, since
            # the redis connection is on the backend_thread event loop
            self.app_instance._state_manager = self.state_manager
        try:
            async with self.app_instance.modify_state(token) as state:
                yield state
        finally:
            if isinstance(self.state_manager, StateManagerRedis):
                self.app_instance._state_manager = app_state_manager
                await self.state_manager.close()

    def poll_for_content(
        self,
        element: WebElement,
        timeout: TimeoutType = None,
        exp_not_equal: str = "",
    ) -> str:
        """Poll element.text for change.

        Args:
            element: selenium webdriver element to check
            timeout: how long to poll element.text
            exp_not_equal: exit the polling loop when the element text does not match

        Returns:
            The element text when the polling loop exited

        Raises:
            TimeoutError: when the timeout expires before text changes
        """
        if not self._poll_for(
            target=lambda: element.text != exp_not_equal,
            timeout=timeout,
        ):
            raise TimeoutError(
                f"{element} content remains {exp_not_equal!r} while polling.",
            )
        return element.text

    def poll_for_value(
        self,
        element: WebElement,
        timeout: TimeoutType = None,
        exp_not_equal: str | Sequence[str] = "",
    ) -> str | None:
        """Poll element.get_attribute("value") for change.

        Args:
            element: selenium webdriver element to check
            timeout: how long to poll element value attribute
            exp_not_equal: exit the polling loop when the value does not match

        Returns:
            The element value when the polling loop exited

        Raises:
            TimeoutError: when the timeout expires before value changes
        """
        # A single string is a sequence of characters; wrap it so `in` compares whole values.
        exp_not_equal = (
            (exp_not_equal,) if isinstance(exp_not_equal, str) else exp_not_equal
        )
        if not self._poll_for(
            target=lambda: element.get_attribute("value") not in exp_not_equal,
            timeout=timeout,
        ):
            raise TimeoutError(
                f"{element} content remains {exp_not_equal!r} while polling.",
            )
        return element.get_attribute("value")

    def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
        """Poll app state_manager for any connected clients.

        Args:
            timeout: how long to wait for client states

        Returns:
            active state instances when the polling loop exited

        Raises:
            RuntimeError: when the app hasn't started running
            TimeoutError: when the timeout expires before any states are seen
            ValueError: when the state_manager is not a memory state manager
        """
        if self.app_instance is None:
            raise RuntimeError("App is not running.")
        state_manager = self.app_instance.state_manager
        if not isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
            raise ValueError("Only works with memory or disk state manager")
        if not self._poll_for(
            target=lambda: state_manager.states,
            timeout=timeout,
        ):
            raise TimeoutError("No states were observed while polling.")
        return state_manager.states

    @staticmethod
    def poll_for_result(
        f: Callable[[], T],
        exception: type[Exception] = Exception,
        max_attempts: int = 5,
        seconds_between_attempts: float = 1,
    ) -> T:
        """Poll for a result from a function.

        Args:
            f: function to call
            exception: exception to catch
            max_attempts: maximum number of attempts
            seconds_between_attempts: seconds to wait between attempts (not
                after the final one)

        Returns:
            Result of the function

        Raises:
            AssertionError: if the function does not return a value; the last
                caught exception (if any) is chained as its cause
        """
        last_error: Exception | None = None
        for attempt in range(max_attempts):
            try:
                return f()
            except exception as err:  # noqa: PERF203
                last_error = err
                # No point sleeping once the final attempt has failed.
                if attempt + 1 < max_attempts:
                    time.sleep(seconds_between_attempts)
        raise AssertionError("Function did not return a value") from last_error
- class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
- """SimpleHTTPRequestHandler with custom error page handling."""
- def __init__(self, *args, error_page_map: dict[int, Path], **kwargs):
- """Initialize the handler.
- Args:
- error_page_map: map of error code to error page path
- *args: passed through to superclass
- **kwargs: passed through to superclass
- """
- self.error_page_map = error_page_map
- super().__init__(*args, **kwargs)
- def send_error(
- self, code: int, message: str | None = None, explain: str | None = None
- ) -> None:
- """Send the error page for the given error code.
- If the code matches a custom error page, then message and explain are
- ignored.
- Args:
- code: the error code
- message: the error message
- explain: the error explanation
- """
- error_page = self.error_page_map.get(code)
- if error_page:
- self.send_response(code, message)
- self.send_header("Connection", "close")
- body = error_page.read_bytes()
- self.send_header("Content-Type", self.error_content_type)
- self.send_header("Content-Length", str(len(body)))
- self.end_headers()
- self.wfile.write(body)
- else:
- super().send_error(code, message, explain)
- class Subdir404TCPServer(socketserver.TCPServer):
- """TCPServer for SimpleHTTPRequestHandlerCustomErrors that serves from a subdir."""
- def __init__(
- self,
- *args,
- root: Path,
- error_page_map: dict[int, Path] | None,
- **kwargs,
- ):
- """Initialize the server.
- Args:
- root: the root directory to serve from
- error_page_map: map of error code to error page path
- *args: passed through to superclass
- **kwargs: passed through to superclass
- """
- self.root = root
- self.error_page_map = error_page_map or {}
- super().__init__(*args, **kwargs)
- def finish_request(self, request: socket.socket, client_address: tuple[str, int]):
- """Finish one request by instantiating RequestHandlerClass.
- Args:
- request: the requesting socket
- client_address: (host, port) referring to the client's address.
- """
- self.RequestHandlerClass(
- request,
- client_address,
- self,
- directory=str(self.root), # pyright: ignore [reportCallIssue]
- error_page_map=self.error_page_map, # pyright: ignore [reportCallIssue]
- )
class AppHarnessProd(AppHarness):
    """AppHarnessProd executes a reflex app in-process for testing.

    In prod mode, instead of running `react-router dev` the app is exported as static
    files and served via the builtin python http.server with custom 404 redirect
    handling. Additionally, the backend runs in multi-worker mode.
    """

    # Thread running _run_frontend (the static file server).
    frontend_thread: threading.Thread | None = None
    # Static file server; published by _run_frontend once bound.
    frontend_server: Subdir404TCPServer | None = None

    def _run_frontend(self):
        """Serve the exported static frontend until ``frontend_server.shutdown()``.

        Runs in ``frontend_thread``. Binds an ephemeral port on all interfaces
        and records the resulting URL in ``frontend_url``.
        """
        web_root = (
            self.app_path
            / reflex.utils.prerequisites.get_web_dir()
            / reflex.constants.Dirs.STATIC
        )
        error_page_map = {
            404: web_root / "404" / "index.html",
        }
        with Subdir404TCPServer(
            ("", 0),
            SimpleHTTPRequestHandlerCustomErrors,
            root=web_root,
            error_page_map=error_page_map,
        ) as server:
            # Publish the URL before the server object, so anyone who observes
            # frontend_server (see _wait_frontend) also sees frontend_url.
            self.frontend_url = "http://localhost:{1}".format(
                *server.socket.getsockname()
            )
            self.frontend_server = server
            server.serve_forever()

    def _start_frontend(self):
        """Point the frontend at the backend, export it statically and start serving it."""
        # Set up the frontend.
        with chdir(self.app_path):
            config = reflex.config.get_config()
            print("Polling for servers...")  # for pytest diagnosis #noqa: T201
            config.api_url = "http://{}:{}".format(
                *self._poll_for_servers(timeout=30).getsockname(),
            )
            print("Building frontend...")  # for pytest diagnosis #noqa: T201

            get_config().loglevel = reflex.constants.LogLevel.INFO

            reflex.utils.prerequisites.assert_in_reflex_dir()

            if reflex.utils.prerequisites.needs_reinit():
                reflex.reflex._init(name=get_config().app_name)

            export(
                zipping=False,
                frontend=True,
                backend=False,
                loglevel=reflex.constants.LogLevel.INFO,
                env=reflex.constants.Env.PROD,
            )

        print("Frontend starting...")  # for pytest diagnosis #noqa: T201
        self.frontend_thread = threading.Thread(target=self._run_frontend)
        self.frontend_thread.start()

    def _wait_frontend(self):
        """Wait until the static server is bound and its URL is known.

        Raises:
            RuntimeError: when the server did not come up within the default
                polling timeout, or its listening socket is already closed.
        """
        self._poll_for(
            lambda: self.frontend_server is not None and self.frontend_url is not None
        )
        if (
            self.frontend_server is None
            or self.frontend_url is None
            # A closed socket reports fileno() == -1 (0 is a valid descriptor).
            or self.frontend_server.socket.fileno() == -1
        ):
            raise RuntimeError("Frontend did not start")

    def _start_backend(self, port: int = 0):
        """Run uvicorn for ``app_asgi`` in a background thread.

        App compilation is skipped (REFLEX_SKIP_COMPILE) until the backend is
        confirmed listening; ``_poll_for_servers`` clears the flag again.

        Args:
            port: port to bind on 127.0.0.1; 0 lets the OS pick a free port.

        Raises:
            RuntimeError: when the app has not been initialized.
        """
        if self.app_asgi is None:
            raise RuntimeError("App was not initialized.")
        environment.REFLEX_SKIP_COMPILE.set(True)
        self.backend = uvicorn.Server(
            uvicorn.Config(
                app=self.app_asgi,
                host="127.0.0.1",
                port=port,
                # NOTE(review): `workers` appears to be honoured only by
                # uvicorn's process supervisor (uvicorn.run); Server.run in a
                # thread likely serves from this process only — confirm.
                workers=reflex.utils.processes.get_num_workers(),
            ),
        )
        self.backend.shutdown = self._get_backend_shutdown_handler()
        print(  # noqa: T201
            "Creating backend in a new thread..."
        )
        self.backend_thread = threading.Thread(target=self.backend.run)
        self.backend_thread.start()
        print("Backend started.")  # for pytest diagnosis #noqa: T201

    def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
        """Poll for the backend socket, then re-enable compilation.

        Args:
            timeout: how long to wait for listening socket.

        Returns:
            first active listening socket on the backend
        """
        try:
            return super()._poll_for_servers(timeout)
        finally:
            environment.REFLEX_SKIP_COMPILE.set(None)

    def stop(self):
        """Stop the backend/browsers, then the frontend python webserver."""
        super().stop()
        if self.frontend_server is not None and (
            self.frontend_thread is None or self.frontend_thread.is_alive()
        ):
            # shutdown() waits for serve_forever() to finish; skip it when the
            # serving thread has already exited, since it could block forever.
            self.frontend_server.shutdown()
        if self.frontend_thread is not None:
            self.frontend_thread.join()
|