浏览代码

Client-side Routing (404 redirect) (#1695)

Masen Furer 1 年之前
父节点
当前提交
38c5503f94

+ 3 - 7
docker-example/Caddyfile

@@ -7,12 +7,8 @@ handle @backend_routes {
 	reverse_proxy app:8000
 }
 
+root * /srv
 route {
-	try_files {path} {path}.html
-	file_server {
-		root /srv
-		pass_thru
-	}
-	# proxy dynamic routes to nextjs server
-	reverse_proxy app:3000
+	try_files {path} {path}/ /404.html
+	file_server
 }

+ 6 - 5
docker-example/Dockerfile

@@ -8,9 +8,6 @@ ARG API_URL
 WORKDIR /app
 COPY . .
 
-# Reflex will install bun, nvm, and node to `$HOME/.reflex` (/app/.reflex)
-ENV HOME=/app
-
 # Create virtualenv which will be copied into final container
 ENV VIRTUAL_ENV=/app/.venv
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
@@ -22,9 +19,13 @@ RUN pip install -r requirements.txt
 # Deploy templates and prepare app
 RUN reflex init
 
-# Export static copy of frontend to /app/.web/_static (and pre-install frontend packages)
+# Export static copy of frontend to /app/.web/_static
 RUN reflex export --frontend-only --no-zip
 
+# Copy static files out of /app to save space in backend image
+RUN mv .web/_static /tmp/_static
+RUN rm -rf .web && mkdir .web
+RUN mv /tmp/_static .web/_static
 
 # Stage 2: copy artifacts into slim image 
 FROM python:3.11-slim
@@ -35,4 +36,4 @@ COPY --chown=reflex --from=init /app /app
 USER reflex
 ENV PATH="/app/.venv/bin:$PATH" API_URL=$API_URL
 
-CMD reflex db migrate && reflex run --env prod
+CMD reflex db migrate && reflex run --env prod --backend-only

+ 17 - 0
integration/conftest.py

@@ -5,6 +5,8 @@ from pathlib import Path
 
 import pytest
 
+from reflex.testing import AppHarness, AppHarnessProd
+
 DISPLAY = None
 XVFB_DIMENSIONS = (800, 600)
 
@@ -57,3 +59,18 @@ def pytest_exception_interact(node, call, report):
         )
     except Exception as e:
         print(f"Failed to take screenshot for {node}: {e}")
+
+
+@pytest.fixture(
+    scope="session", params=[AppHarness, AppHarnessProd], ids=["dev", "prod"]
+)
+def app_harness_env(request):
+    """Parametrize the AppHarness class to use for the test, either dev or prod.
+
+    Args:
+        request: The pytest fixture request object.
+
+    Returns:
+        The AppHarness class to use for the test.
+    """
+    return request.param

+ 139 - 32
integration/test_dynamic_routes.py

@@ -1,12 +1,12 @@
 """Integration tests for dynamic route page behavior."""
-import time
-from typing import Generator
+from typing import Callable, Generator, Type
 from urllib.parse import urlsplit
 
 import pytest
 from selenium.webdriver.common.by import By
 
-from reflex.testing import AppHarness
+from reflex import State
+from reflex.testing import AppHarness, AppHarnessProd, WebDriver
 
 from .utils import poll_for_navigation
 
@@ -20,7 +20,14 @@ def DynamicRoute():
         page_id: str = ""
 
         def on_load(self):
-            self.order.append(self.page_id or "no page id")
+            self.order.append(
+                f"{self.get_current_page()}-{self.page_id or 'no page id'}"
+            )
+
+        def on_load_redir(self):
+            query_params = self.get_query_params()
+            self.order.append(f"on_load_redir-{query_params}")
+            return rx.redirect(f"/page/{query_params['page_id']}")
 
         @rx.var
         def next_page(self) -> str:
@@ -42,37 +49,46 @@ def DynamicRoute():
             rx.link(
                 "next", href="/page/" + DynamicState.next_page, id="link_page_next"  # type: ignore
             ),
+            rx.link("missing", href="/missing", id="link_missing"),
             rx.list(
                 rx.foreach(DynamicState.order, lambda i: rx.list_item(rx.text(i))),  # type: ignore
             ),
         )
 
+    @rx.page(route="/redirect-page/[page_id]", on_load=DynamicState.on_load_redir)  # type: ignore
+    def redirect_page():
+        return rx.fragment(rx.text("redirecting..."))
+
     app = rx.App(state=DynamicState)
     app.add_page(index)
     app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load)  # type: ignore
     app.add_page(index, route="/static/x", on_load=DynamicState.on_load)  # type: ignore
+    app.add_custom_404_page(on_load=DynamicState.on_load)  # type: ignore
     app.compile()
 
 
 @pytest.fixture(scope="session")
-def dynamic_route(tmp_path_factory) -> Generator[AppHarness, None, None]:
+def dynamic_route(
+    app_harness_env: Type[AppHarness], tmp_path_factory
+) -> Generator[AppHarness, None, None]:
     """Start DynamicRoute app at tmp_path via AppHarness.
 
     Args:
+        app_harness_env: either AppHarness (dev) or AppHarnessProd (prod)
         tmp_path_factory: pytest tmp_path_factory fixture
 
     Yields:
         running AppHarness instance
     """
-    with AppHarness.create(
-        root=tmp_path_factory.mktemp("dynamic_route"),
+    with app_harness_env.create(
+        root=tmp_path_factory.mktemp(f"dynamic_route"),
         app_source=DynamicRoute,  # type: ignore
     ) as harness:
         yield harness
 
 
 @pytest.fixture
-def driver(dynamic_route: AppHarness):
+def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]:
     """Get an instance of the browser open to the dynamic_route app.
 
     Args:
@@ -90,23 +106,71 @@ def driver(dynamic_route: AppHarness):
         driver.quit()
 
 
-def test_on_load_navigate(dynamic_route: AppHarness, driver):
-    """Click links to navigate between dynamic pages with on_load event.
+@pytest.fixture()
+def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
+    """Get the backend state.
 
     Args:
         dynamic_route: harness for DynamicRoute app.
         driver: WebDriver instance.
+
+    Returns:
+        The backend state associated with the token visible in the driver browser.
     """
     assert dynamic_route.app_instance is not None
     token_input = driver.find_element(By.ID, "token")
-    link = driver.find_element(By.ID, "link_page_next")
     assert token_input
-    assert link
 
     # wait for the backend connection to send the token
     token = dynamic_route.poll_for_value(token_input)
     assert token is not None
 
+    # look up the backend state from the state manager
+    return dynamic_route.app_instance.state_manager.states[token]
+
+
+@pytest.fixture()
+def poll_for_order(
+    dynamic_route: AppHarness, backend_state: State
+) -> Callable[[list[str]], None]:
+    """Poll for the order list to match the expected order.
+
+    Args:
+        dynamic_route: harness for DynamicRoute app.
+        backend_state: The backend state associated with the token visible in the driver browser.
+
+    Returns:
+        A function that polls for the order list to match the expected order.
+    """
+
+    def _poll_for_order(exp_order: list[str]):
+        dynamic_route._poll_for(lambda: backend_state.order == exp_order)
+        assert backend_state.order == exp_order
+
+    return _poll_for_order
+
+
+def test_on_load_navigate(
+    dynamic_route: AppHarness,
+    driver: WebDriver,
+    backend_state: State,
+    poll_for_order: Callable[[list[str]], None],
+):
+    """Click links to navigate between dynamic pages with on_load event.
+
+    Args:
+        dynamic_route: harness for DynamicRoute app.
+        driver: WebDriver instance.
+        backend_state: The backend state associated with the token visible in the driver browser.
+        poll_for_order: function that polls for the order list to match the expected order.
+    """
+    assert dynamic_route.app_instance is not None
+    is_prod = isinstance(dynamic_route, AppHarnessProd)
+    link = driver.find_element(By.ID, "link_page_next")
+    assert link
+
+    exp_order = [f"/page/[page-id]-{ix}" for ix in range(10)]
+
     # click the link a few times
     for ix in range(10):
         # wait for navigation, then assert on url
@@ -121,40 +185,84 @@ def test_on_load_navigate(dynamic_route: AppHarness, driver):
         assert page_id_input
 
         assert dynamic_route.poll_for_value(page_id_input) == str(ix)
+    poll_for_order(exp_order)
 
-    # look up the backend state and assert that `on_load` was called for all
-    # navigation events
-    backend_state = dynamic_route.app_instance.state_manager.states[token]
-    time.sleep(0.2)
-    assert backend_state.order == [str(ix) for ix in range(10)]
+    # manually load the next page to trigger client side routing in prod mode
+    if is_prod:
+        exp_order += ["/404-no page id"]
+    exp_order += ["/page/[page-id]-10"]
+    with poll_for_navigation(driver):
+        driver.get(f"{dynamic_route.frontend_url}/page/10/")
+    poll_for_order(exp_order)
 
+    # make sure internal nav still hydrates after redirect
+    exp_order += ["/page/[page-id]-11"]
+    link = driver.find_element(By.ID, "link_page_next")
+    with poll_for_navigation(driver):
+        link.click()
+    poll_for_order(exp_order)
 
-def test_on_load_navigate_non_dynamic(dynamic_route: AppHarness, driver):
-    """Click links to navigate between static pages with on_load event.
+    # load same page with a query param and make sure it passes through
+    if is_prod:
+        exp_order += ["/404-no page id"]
+    exp_order += ["/page/[page-id]-11"]
+    with poll_for_navigation(driver):
+        driver.get(f"{driver.current_url}?foo=bar")
+    poll_for_order(exp_order)
+    assert backend_state.get_query_params()["foo"] == "bar"
+
+    # hit a 404 and ensure we still hydrate
+    exp_order += ["/404-no page id"]
+    with poll_for_navigation(driver):
+        driver.get(f"{dynamic_route.frontend_url}/missing")
+    poll_for_order(exp_order)
+
+    # browser nav should still trigger hydration
+    if is_prod:
+        exp_order += ["/404-no page id"]
+    exp_order += ["/page/[page-id]-11"]
+    with poll_for_navigation(driver):
+        driver.back()
+    poll_for_order(exp_order)
+
+    # next/link to a 404 and ensure we still hydrate
+    exp_order += ["/404-no page id"]
+    link = driver.find_element(By.ID, "link_missing")
+    with poll_for_navigation(driver):
+        link.click()
+    poll_for_order(exp_order)
+
+    # hit a page that redirects back to dynamic page
+    if is_prod:
+        exp_order += ["/404-no page id"]
+    exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page-id]-0"]
+    with poll_for_navigation(driver):
+        driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
+    poll_for_order(exp_order)
+    # should have redirected back to page 0
+    assert urlsplit(driver.current_url).path == "/page/0/"
 
 
+def test_on_load_navigate_non_dynamic(
+    dynamic_route: AppHarness,
+    driver: WebDriver,
+    poll_for_order: Callable[[list[str]], None],
+):
+    """Click links to navigate between static pages with on_load event.
+
     Args:
         dynamic_route: harness for DynamicRoute app.
         driver: WebDriver instance.
+        poll_for_order: function that polls for the order list to match the expected order.
     """
     assert dynamic_route.app_instance is not None
-    token_input = driver.find_element(By.ID, "token")
     link = driver.find_element(By.ID, "link_page_x")
-    assert token_input
     assert link
 
-    # wait for the backend connection to send the token
-    token = dynamic_route.poll_for_value(token_input)
-    assert token is not None
-
     with poll_for_navigation(driver):
         link.click()
     assert urlsplit(driver.current_url).path == "/static/x/"
-
-    # look up the backend state and assert that `on_load` was called once
-    backend_state = dynamic_route.app_instance.state_manager.states[token]
-    time.sleep(0.2)
-    assert backend_state.order == ["no page id"]
+    poll_for_order(["/static/x-no page id"])
 
     # go back to the index and navigate back to the static route
     link = driver.find_element(By.ID, "link_index")
@@ -166,5 +274,4 @@ def test_on_load_navigate_non_dynamic(dynamic_route: AppHarness, driver):
     with poll_for_navigation(driver):
         link.click()
     assert urlsplit(driver.current_url).path == "/static/x/"
-    time.sleep(0.2)
-    assert backend_state.order == ["no page id", "no page id"]
+    poll_for_order(["/static/x-no page id", "/static/x-no page id"])

+ 0 - 19
reflex/.templates/web/pages/404.js

@@ -1,19 +0,0 @@
-import Router from "next/router";
-import { useEffect, useState } from "react";
-
-export default function Custom404() {
-  const [isNotFound, setIsNotFound] = useState(false);
-
-  useEffect(() => {
-    const pathNameArray = window.location.pathname.split("/");
-    if (pathNameArray.length == 2 && pathNameArray[1] == "404") {
-      setIsNotFound(true);
-    } else {
-      Router.replace(window.location.pathname);
-    }
-  }, []);
-
-  if (isNotFound) return <h1>404 - Page Not Found</h1>;
-
-  return null;
-}

+ 36 - 0
reflex/.templates/web/utils/client_side_routing.js

@@ -0,0 +1,36 @@
+import { useEffect, useRef, useState } from "react";
+import { useRouter } from "next/router";
+
+/**
+ * React hook for use in /404 page to enable client-side routing.
+ *
+ * Uses the next/router to redirect to the provided URL when loading
+ * the 404 page (for example as a fallback in static hosting situations).
+ *
+ * @returns {boolean} routeNotFound - true if the current route is an actual 404
+ */
+export const useClientSideRouting = () => {
+  const [routeNotFound, setRouteNotFound] = useState(false)
+  const didRedirect = useRef(false)
+  const router = useRouter()
+  useEffect(() => {
+    if (
+      router.isReady &&
+      !didRedirect.current  // have not tried redirecting yet
+    ) {
+      didRedirect.current = true  // never redirect twice to avoid "Hard Navigate" error
+      // attempt to redirect to the route in the browser address bar once
+      router.replace({
+          pathname: window.location.pathname,
+          query: window.location.search.slice(1),
+      })
+      .catch((e) => {
+        setRouteNotFound(true)  // navigation failed, so this is a real 404
+      })
+    }
+  }, [router.isReady]);
+
+  // Return the reactive bool, to avoid flashing 404 page until we know for sure
+  // the route is not found.
+  return routeNotFound
+}

+ 18 - 9
reflex/.templates/web/utils/state.js

@@ -173,10 +173,13 @@ export const applyEvent = async (event, socket) => {
     return false;
   }
 
-  // Send the event to the server.
-  event.token = getToken();
-  event.router_data = (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(Router);
+  // Update token and router data (if missing).
+  event.token = getToken()
+  if (event.router_data === undefined || Object.keys(event.router_data).length === 0) {
+    event.router_data = (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(Router)
+  }
 
+  // Send the event to the server.
   if (socket) {
     socket.emit("event", JSON.stringify(event));
     return true;
@@ -255,7 +258,6 @@ export const processEvent = async (
  * @param dispatch The function to queue state update
  * @param transports The transports to use.
  * @param setConnectError The function to update connection error value.
- * @param initial_events Array of events to seed the queue after connecting.
  * @param client_storage The client storage object from context.js
  */
 export const connect = async (
@@ -263,7 +265,6 @@ export const connect = async (
   dispatch,
   transports,
   setConnectError,
-  initial_events = [],
   client_storage = {},
 ) => {
   // Get backend URL object from the endpoint.
@@ -277,7 +278,6 @@ export const connect = async (
 
   // Once the socket is open, hydrate the page.
   socket.current.on("connect", () => {
-    queueEvents(initial_events, socket)
     setConnectError(null)
   });
 
@@ -427,8 +427,8 @@ const applyClientStorageDelta = (client_storage, delta) => {
 
 /**
  * Establish websocket event loop for a NextJS page.
- * @param initial_state The initial page state.
- * @param initial_events Array of events to seed the queue after connecting.
+ * @param initial_state The initial app state.
+ * @param initial_events The initial app events.
  * @param client_storage The client storage object from context.js
  *
  * @returns [state, Event, connectError] -
@@ -452,6 +452,15 @@ export const useEventLoop = (
       queueEvents(events, socket)
   }
 
+  const sentHydrate = useRef(false);  // Avoid double-hydrate due to React strict-mode
+  // initial state hydrate
+  useEffect(() => {
+    if (router.isReady && !sentHydrate.current) {
+      Event(initial_events.map((e) => ({...e})))
+      sentHydrate.current = true
+    }
+  }, [router.isReady])
+
   // Main event loop.
   useEffect(() => {
     // Skip if the router is not ready.
@@ -461,7 +470,7 @@ export const useEventLoop = (
 
     // Initialize the websocket connection.
     if (!socket.current) {
-      connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events, client_storage)
+      connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
     }
     (async () => {
       // Process all outstanding events.

+ 11 - 1
reflex/app.py

@@ -32,6 +32,10 @@ from reflex.compiler import utils as compiler_utils
 from reflex.components import connection_modal
 from reflex.components.component import Component, ComponentStyle
 from reflex.components.layout.fragment import Fragment
+from reflex.components.navigation.client_side_routing import (
+    Default404Page,
+    wait_for_client_redirect,
+)
 from reflex.config import get_config
 from reflex.event import Event, EventHandler, EventSpec
 from reflex.middleware import HydrateMiddleware, Middleware
@@ -451,8 +455,10 @@ class App(Base):
             on_load: The event handler(s) that will be called each time the page load.
             meta: The metadata of the page.
         """
+        if component is None:
+            component = Default404Page.create()
         self.add_page(
-            component=component if component else Fragment.create(),
+            component=wait_for_client_redirect(self._generate_component(component)),
             route=constants.SLUG_404,
             title=title or constants.TITLE_404,
             image=image or constants.FAVICON_404,
@@ -533,6 +539,10 @@ class App(Base):
         for render, kwargs in DECORATED_PAGES:
             self.add_page(render, **kwargs)
 
+        # Render a default 404 page if the user didn't supply one
+        if constants.SLUG_404 not in self.pages:
+            self.add_custom_404_page()
+
         task = progress.add_task("Compiling: ", total=len(self.pages))
         # TODO: include all work done in progress indicator, not just self.pages
 

+ 1 - 1
reflex/compiler/compiler.py

@@ -276,5 +276,5 @@ def compile_tailwind(
 
 def purge_web_pages_dir():
     """Empty out .web directory."""
-    template_files = ["_app.js", "404.js"]
+    template_files = ["_app.js"]
     utils.empty_dir(constants.WEB_PAGES_DIR, keep_files=template_files)

+ 69 - 0
reflex/components/navigation/client_side_routing.py

@@ -0,0 +1,69 @@
+"""Handle dynamic routes in static exports via client-side routing.
+
+Works with /utils/client_side_routing.js to handle the redirect and state.
+
+When the user hits a 404 accessing a route, redirect them to the same page,
+setting a reactive state var "routeNotFound" to true if the redirect fails.  The
+`wait_for_client_redirect` function will render the component only after
+routeNotFound becomes true.
+"""
+from __future__ import annotations
+
+from reflex import constants
+
+from ...vars import Var
+from ..component import Component
+from ..layout.cond import Cond
+
+route_not_found = Var.create_safe(constants.ROUTE_NOT_FOUND)
+
+
+class ClientSideRouting(Component):
+    """The client-side routing component."""
+
+    library = "/utils/client_side_routing"
+    tag = "useClientSideRouting"
+
+    def _get_hooks(self) -> str:
+        """Get the hooks to render.
+
+        Returns:
+            The useClientSideRouting hook.
+        """
+        return f"const {constants.ROUTE_NOT_FOUND} = {self.tag}()"
+
+    def render(self) -> str:
+        """Render the component.
+
+        Returns:
+            Empty string, because this component is only used for its hooks.
+        """
+        return ""
+
+
+def wait_for_client_redirect(component) -> Component:
+    """Wait for a redirect to occur before rendering a component.
+
+    This prevents the 404 page from flashing while the redirect is happening.
+
+    Args:
+        component: The component to render after the redirect.
+
+    Returns:
+        The conditionally rendered component.
+    """
+    return Cond.create(
+        cond=route_not_found,
+        comp1=component,
+        comp2=ClientSideRouting.create(),
+    )
+
+
+class Default404Page(Component):
+    """The NextJS default 404 page."""
+
+    library = "next/error"
+    tag = "Error"
+    is_default = True
+
+    status_code: Var[int] = 404  # type: ignore

+ 1 - 0
reflex/constants.py

@@ -344,6 +344,7 @@ SLUG_404 = "404"
 TITLE_404 = "404 - Not Found"
 FAVICON_404 = "favicon.ico"
 DESCRIPTION_404 = "The page was not found"
+ROUTE_NOT_FOUND = "routeNotFound"
 
 # Color mode variables
 USE_COLOR_MODE = "useColorMode"

+ 7 - 1
reflex/event.py

@@ -449,12 +449,17 @@ def get_handler_args(event_spec: EventSpec, arg: Var) -> tuple[tuple[Var, Var],
     return event_spec.args if len(args) > 1 else tuple()
 
 
-def fix_events(events: list[EventHandler | EventSpec], token: str) -> list[Event]:
+def fix_events(
+    events: list[EventHandler | EventSpec],
+    token: str,
+    router_data: dict[str, Any] | None = None,
+) -> list[Event]:
     """Fix a list of events returned by an event handler.
 
     Args:
         events: The events to fix.
         token: The user token.
+        router_data: The optional router data to set in the event.
 
     Returns:
         The fixed events.
@@ -485,6 +490,7 @@ def fix_events(events: list[EventHandler | EventSpec], token: str) -> list[Event
                 token=token,
                 name=name,
                 payload=payload,
+                router_data=router_data or {},
             )
         )
 

+ 1 - 1
reflex/middleware/hydrate_middleware.py

@@ -60,7 +60,7 @@ class HydrateMiddleware(Middleware):
 
         # Add the on_load events and set is_hydrated to True.
         events = [*app.get_load_events(route), type(state).set_is_hydrated(True)]  # type: ignore
-        events = fix_events(events, event.token)
+        events = fix_events(events, event.token, router_data=event.router_data)
 
         # Return the state update.
         return StateUpdate(delta=delta, events=events)

+ 162 - 2
reflex/testing.py

@@ -10,11 +10,13 @@ import platform
 import re
 import signal
 import socket
+import socketserver
 import subprocess
 import textwrap
 import threading
 import time
 import types
+from http.server import SimpleHTTPRequestHandler
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -156,9 +158,9 @@ class AppHarness:
                 )
                 self.app_module_path.write_text(source_code)
         with chdir(self.app_path):
-            # ensure config is reloaded when testing different app
+            # ensure config and app are reloaded when testing different app
             reflex.config.get_config(reload=True)
-            self.app_module = reflex.utils.prerequisites.get_app()
+            self.app_module = reflex.utils.prerequisites.get_app(reload=True)
         self.app_instance = self.app_module.app
 
     def _start_backend(self):
@@ -461,3 +463,161 @@ class AppHarness:
         ):
             raise TimeoutError("No states were observed while polling.")
         return state_manager.states
+
+
+class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
+    """SimpleHTTPRequestHandler with custom error page handling."""
+
+    def __init__(self, *args, error_page_map: dict[int, pathlib.Path], **kwargs):
+        """Initialize the handler.
+
+        Args:
+            error_page_map: map of error code to error page path
+            *args: passed through to superclass
+            **kwargs: passed through to superclass
+        """
+        self.error_page_map = error_page_map
+        super().__init__(*args, **kwargs)
+
+    def send_error(
+        self, code: int, message: str | None = None, explain: str | None = None
+    ) -> None:
+        """Send the error page for the given error code.
+
+        If the code matches a custom error page, then message and explain are
+        ignored.
+
+        Args:
+            code: the error code
+            message: the error message
+            explain: the error explanation
+        """
+        error_page = self.error_page_map.get(code)
+        if error_page:
+            self.send_response(code, message)
+            self.send_header("Connection", "close")
+            body = error_page.read_bytes()
+            self.send_header("Content-Type", self.error_content_type)
+            self.send_header("Content-Length", str(len(body)))
+            self.end_headers()
+            self.wfile.write(body)
+        else:
+            super().send_error(code, message, explain)
+
+
+class Subdir404TCPServer(socketserver.TCPServer):
+    """TCPServer for SimpleHTTPRequestHandlerCustomErrors that serves from a subdir."""
+
+    def __init__(
+        self,
+        *args,
+        root: pathlib.Path,
+        error_page_map: dict[int, pathlib.Path] | None,
+        **kwargs,
+    ):
+        """Initialize the server.
+
+        Args:
+            root: the root directory to serve from
+            error_page_map: map of error code to error page path
+            *args: passed through to superclass
+            **kwargs: passed through to superclass
+        """
+        self.root = root
+        self.error_page_map = error_page_map or {}
+        super().__init__(*args, **kwargs)
+
+    def finish_request(self, request: socket.socket, client_address: tuple[str, int]):
+        """Finish one request by instantiating RequestHandlerClass.
+
+        Args:
+            request: the requesting socket
+            client_address: (host, port) referring to the client’s address.
+        """
+        print(client_address, type(client_address))
+        self.RequestHandlerClass(
+            request,
+            client_address,
+            self,
+            directory=str(self.root),  # type: ignore
+            error_page_map=self.error_page_map,  # type: ignore
+        )
+
+
+class AppHarnessProd(AppHarness):
+    """AppHarnessProd executes a reflex app in-process for testing.
+
+    In prod mode, instead of running `next dev` the app is exported as static
+    files and served via the builtin python http.server with custom 404 redirect
+    handling. Additionally, the backend runs in multi-worker mode.
+    """
+
+    frontend_thread: Optional[threading.Thread] = None
+    frontend_server: Optional[Subdir404TCPServer] = None
+
+    def _run_frontend(self):
+        web_root = self.app_path / reflex.constants.WEB_DIR / "_static"
+        error_page_map = {
+            404: web_root / "404.html",
+        }
+        with Subdir404TCPServer(
+            ("", 0),
+            SimpleHTTPRequestHandlerCustomErrors,
+            root=web_root,
+            error_page_map=error_page_map,
+        ) as self.frontend_server:
+            self.frontend_url = "http://localhost:{1}".format(
+                *self.frontend_server.socket.getsockname()
+            )
+            self.frontend_server.serve_forever()
+
+    def _start_frontend(self):
+        # Set up the frontend.
+        with chdir(self.app_path):
+            config = reflex.config.get_config()
+            config.api_url = "http://{0}:{1}".format(
+                *self._poll_for_servers().getsockname(),
+            )
+            reflex.reflex.export(
+                zipping=False,
+                frontend=True,
+                backend=False,
+                loglevel=reflex.constants.LogLevel.INFO,
+            )
+
+        self.frontend_thread = threading.Thread(target=self._run_frontend)
+        self.frontend_thread.start()
+
+    def _wait_frontend(self):
+        self._poll_for(lambda: self.frontend_server is not None)
+        if self.frontend_server is None or not self.frontend_server.socket.fileno():
+            raise RuntimeError("Frontend did not start")
+
+    def _start_backend(self):
+        if self.app_instance is None:
+            raise RuntimeError("App was not initialized.")
+        os.environ[reflex.constants.SKIP_COMPILE_ENV_VAR] = "yes"
+        self.backend = uvicorn.Server(
+            uvicorn.Config(
+                app=self.app_instance,
+                host="127.0.0.1",
+                port=0,
+                workers=reflex.utils.processes.get_num_workers(),
+            ),
+        )
+        self.backend_thread = threading.Thread(target=self.backend.run)
+        self.backend_thread.start()
+
+    def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
+        try:
+            return super()._poll_for_servers(timeout)
+        finally:
+            os.environ.pop(reflex.constants.SKIP_COMPILE_ENV_VAR, None)
+
+    def stop(self):
+        """Stop the frontend python webserver."""
+        super().stop()
+        if self.frontend_server is not None:
+            self.frontend_server.shutdown()
+        if self.frontend_thread is not None:
+            self.frontend_thread.join()

+ 9 - 2
reflex/utils/prerequisites.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import glob
+import importlib
 import json
 import os
 import platform
@@ -97,16 +98,22 @@ def get_package_manager() -> str | None:
     return path_ops.get_npm_path()
 
 
-def get_app() -> ModuleType:
+def get_app(reload: bool = False) -> ModuleType:
     """Get the app module based on the default config.
 
+    Args:
+        reload: Re-import the app module from disk
+
     Returns:
         The app based on the default config.
     """
     config = get_config()
     module = ".".join([config.app_name, config.app_name])
     sys.path.insert(0, os.getcwd())
-    return __import__(module, fromlist=(constants.APP_VAR,))
+    app = __import__(module, fromlist=(constants.APP_VAR,))
+    if reload:
+        importlib.reload(app)
+    return app
 
 
 def get_redis() -> Redis | None:

+ 15 - 3
tests/test_app.py

@@ -809,9 +809,17 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
 
     for exp_index, exp_val in enumerate(exp_vals):
+        hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
+        exp_router_data = {
+            "headers": {},
+            "ip": client_ip,
+            "sid": sid,
+            "token": token,
+            **hydrate_event.router_data,
+        }
         update = await process(
             app,
-            event=_event(name=get_hydrate_event(state), val=exp_val),
+            event=hydrate_event,
             sid=sid,
             headers={},
             client_ip=client_ip,
@@ -830,12 +838,16 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                 }
             },
             events=[
-                _dynamic_state_event(name="on_load", val=exp_val, router_data={}),
+                _dynamic_state_event(
+                    name="on_load",
+                    val=exp_val,
+                    router_data=exp_router_data,
+                ),
                 _dynamic_state_event(
                     name="set_is_hydrated",
                     payload={"value": True},
                     val=exp_val,
-                    router_data={},
+                    router_data=exp_router_data,
                 ),
             ],
         )