Browse Source

Catch unhandled errors on both frontend and backend (#3572)

Maxim Vlah 10 months ago
parent
commit
772a3ef893

+ 152 - 0
integration/test_exception_handlers.py

@@ -0,0 +1,152 @@
+"""Integration tests for event exception handlers."""
+
+from __future__ import annotations
+
+import time
+from typing import Generator, Type
+
+import pytest
+from selenium.webdriver.common.by import By
+from selenium.webdriver.remote.webdriver import WebDriver
+from selenium.webdriver.support import expected_conditions as EC
+from selenium.webdriver.support.ui import WebDriverWait
+
+from reflex.testing import AppHarness
+
+
+def TestApp():
+    """A test app for event exception handler integration."""
+    import reflex as rx
+
+    class TestAppConfig(rx.Config):
+        """Config for the TestApp app."""
+
+    class TestAppState(rx.State):
+        """State for the TestApp app."""
+
+        def divide_by_number(self, number: int):
+            """Divide by number and print the result.
+
+            Args:
+                number: number to divide by
+
+            """
+            print(1 / number)
+
+    app = rx.App(state=rx.State)
+
+    @app.add_page
+    def index():
+        return rx.vstack(
+            rx.button(
+                "induce_frontend_error",
+                on_click=rx.call_script("induce_frontend_error()"),
+                id="induce-frontend-error-btn",
+            ),
+            rx.button(
+                "induce_backend_error",
+                on_click=lambda: TestAppState.divide_by_number(0),  # type: ignore
+                id="induce-backend-error-btn",
+            ),
+        )
+
+
+@pytest.fixture(scope="module")
+def test_app(
+    app_harness_env: Type[AppHarness], tmp_path_factory
+) -> Generator[AppHarness, None, None]:
+    """Start TestApp app at tmp_path via AppHarness.
+
+    Args:
+        app_harness_env: either AppHarness (dev) or AppHarnessProd (prod)
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+
+    """
+    with app_harness_env.create(
+        root=tmp_path_factory.mktemp("test_app"),
+        app_name=f"testapp_{app_harness_env.__name__.lower()}",
+        app_source=TestApp,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.fixture
+def driver(test_app: AppHarness) -> Generator[WebDriver, None, None]:
+    """Get an instance of the browser open to the test_app app.
+
+    Args:
+        test_app: harness for TestApp app
+
+    Yields:
+        WebDriver instance.
+
+    """
+    assert test_app.app_instance is not None, "app is not running"
+    driver = test_app.frontend()
+    try:
+        yield driver
+    finally:
+        driver.quit()
+
+
+def test_frontend_exception_handler_during_runtime(
+    driver: WebDriver,
+    capsys,
+):
+    """Test calling frontend exception handler during runtime.
+
+    We send an event containing a call to a non-existent function in the frontend.
+    This should trigger the default frontend exception handler.
+
+    Args:
+        driver: WebDriver instance.
+        capsys: pytest fixture for capturing stdout and stderr.
+
+    """
+    reset_button = WebDriverWait(driver, 20).until(
+        EC.element_to_be_clickable((By.ID, "induce-frontend-error-btn"))
+    )
+
+    reset_button.click()
+
+    # Wait for the error to be logged
+    time.sleep(2)
+
+    captured_default_handler_output = capsys.readouterr()
+    assert (
+        "induce_frontend_error" in captured_default_handler_output.out
+        and "ReferenceError" in captured_default_handler_output.out
+    )
+
+
+def test_backend_exception_handler_during_runtime(
+    driver: WebDriver,
+    capsys,
+):
+    """Test calling backend exception handler during runtime.
+
+    We invoke TestAppState.divide_by_zero to induce backend error.
+    This should trigger the default backend exception handler.
+
+    Args:
+        driver: WebDriver instance.
+        capsys: pytest fixture for capturing stdout and stderr.
+
+    """
+    reset_button = WebDriverWait(driver, 20).until(
+        EC.element_to_be_clickable((By.ID, "induce-backend-error-btn"))
+    )
+
+    reset_button.click()
+
+    # Wait for the error to be logged
+    time.sleep(2)
+
+    captured_default_handler_output = capsys.readouterr()
+    assert (
+        "divide_by_number" in captured_default_handler_output.out
+        and "ZeroDivisionError" in captured_default_handler_output.out
+    )

+ 28 - 0
reflex/.templates/web/utils/state.js

@@ -247,6 +247,9 @@ export const applyEvent = async (event, socket) => {
       }
     } catch (e) {
       console.log("_call_script", e);
+      if (window && window?.onerror) {
+        window.onerror(e.message, null, null, null, e)
+      }
     }
     return false;
   }
@@ -687,6 +690,31 @@ export const useEventLoop = (
     }
   }, [router.isReady]);
 
+    // Handle frontend errors and send them to the backend via websocket.
+    useEffect(() => {
+      
+      if (typeof window === 'undefined') {
+        return;
+      }
+  
+      window.onerror = function (msg, url, lineNo, columnNo, error) {
+        addEvents([Event("state.frontend_event_exception_state.handle_frontend_exception", {
+          stack: error.stack,
+        })])
+        return false;
+      }
+
+      //NOTE: Only works in Chrome v49+
+      //https://github.com/mknichel/javascript-errors?tab=readme-ov-file#promise-rejection-events
+      window.onunhandledrejection = function (event) {
+          addEvents([Event("state.frontend_event_exception_state.handle_frontend_exception", {
+            stack: event.reason.stack,
+          })])
+          return false;
+      }
+  
+    },[])
+
   // Main event loop.
   useEffect(() => {
     // Skip if the router is not ready.

+ 175 - 1
reflex/app.py

@@ -7,11 +7,13 @@ import concurrent.futures
 import contextlib
 import copy
 import functools
+import inspect
 import io
 import multiprocessing
 import os
 import platform
 import sys
+import traceback
 from datetime import datetime
 from typing import (
     Any,
@@ -45,6 +47,7 @@ from reflex.compiler import compiler
 from reflex.compiler import utils as compiler_utils
 from reflex.compiler.compiler import ExecutorSafeFunctions
 from reflex.components.base.app_wrap import AppWrap
+from reflex.components.base.error_boundary import ErrorBoundary
 from reflex.components.base.fragment import Fragment
 from reflex.components.component import (
     Component,
@@ -60,7 +63,7 @@ from reflex.components.core.client_side_routing import (
 from reflex.components.core.upload import Upload, get_upload_dir
 from reflex.components.radix import themes
 from reflex.config import get_config
-from reflex.event import Event, EventHandler, EventSpec
+from reflex.event import Event, EventHandler, EventSpec, window_alert
 from reflex.model import Model
 from reflex.page import (
     DECORATED_PAGES,
@@ -88,6 +91,33 @@ ComponentCallable = Callable[[], Component]
 Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
 
 
+def default_frontend_exception_handler(exception: Exception) -> None:
+    """Default frontend exception handler function.
+
+    Args:
+        exception: The exception.
+
+    """
+    console.error(f"[Reflex Frontend Exception]\n {exception}\n")
+
+
+def default_backend_exception_handler(exception: Exception) -> EventSpec:
+    """Default backend exception handler function.
+
+    Args:
+        exception: The exception.
+
+    Returns:
+        EventSpec: The window alert event.
+
+    """
+    error = traceback.format_exc()
+
+    console.error(f"[Reflex Backend Exception]\n {error}\n")
+
+    return window_alert("An error occurred. See logs for details.")
+
+
 def default_overlay_component() -> Component:
     """Default overlay_component attribute for App.
 
@@ -101,6 +131,16 @@ def default_overlay_component() -> Component:
     )
 
 
+def default_error_boundary() -> Component:
+    """Default error_boundary attribute for App.
+
+    Returns:
+        The default error_boundary, which is an ErrorBoundary.
+
+    """
+    return ErrorBoundary.create()
+
+
 class OverlayFragment(Fragment):
     """Alias for Fragment, used to wrap the overlay_component."""
 
@@ -142,6 +182,11 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         default_overlay_component
     )
 
+    # Error boundary component to wrap the app with.
+    error_boundary: Optional[Union[Component, ComponentCallable]] = (
+        default_error_boundary
+    )
+
     # Components to add to the head of every page.
     head_components: List[Component] = []
 
@@ -178,6 +223,16 @@ class App(MiddlewareMixin, LifespanMixin, Base):
     # Background tasks that are currently running. PRIVATE.
     background_tasks: Set[asyncio.Task] = set()
 
+    # Frontend Error Handler Function
+    frontend_exception_handler: Callable[[Exception], None] = (
+        default_frontend_exception_handler
+    )
+
+    # Backend Error Handler Function
+    backend_exception_handler: Callable[
+        [Exception], Union[EventSpec, List[EventSpec], None]
+    ] = default_backend_exception_handler
+
     def __init__(self, **kwargs):
         """Initialize the app.
 
@@ -279,6 +334,9 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         # Mount the socket app with the API.
         self.api.mount(str(constants.Endpoint.EVENT), socket_app)
 
+        # Check the exception handlers
+        self._validate_exception_handlers()
+
     def __repr__(self) -> str:
         """Get the string representation of the app.
 
@@ -688,6 +746,25 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         for k, component in self.pages.items():
             self.pages[k] = self._add_overlay_to_component(component)
 
+    def _add_error_boundary_to_component(self, component: Component) -> Component:
+        if self.error_boundary is None:
+            return component
+
+        component = ErrorBoundary.create(*component.children)
+
+        return component
+
+    def _setup_error_boundary(self):
+        """If a State is not used and no error_boundary is specified, do not render the error boundary."""
+        if self.state is None and self.error_boundary is default_error_boundary:
+            self.error_boundary = None
+
+        for k, component in self.pages.items():
+            # Skip the 404 page
+            if k == constants.Page404.SLUG:
+                continue
+            self.pages[k] = self._add_error_boundary_to_component(component)
+
     def _apply_decorated_pages(self):
         """Add @rx.page decorated pages to the app.
 
@@ -757,6 +834,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
 
         self._validate_var_dependencies()
         self._setup_overlay_component()
+        self._setup_error_boundary()
 
         # Create a progress bar.
         progress = Progress(
@@ -1036,6 +1114,100 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         task.add_done_callback(self.background_tasks.discard)
         return task
 
+    def _validate_exception_handlers(self):
+        """Validate the custom event exception handlers for front- and backend.
+
+        Raises:
+            ValueError: If the custom exception handlers are invalid.
+
+        """
+        FRONTEND_ARG_SPEC = {
+            "exception": Exception,
+        }
+
+        BACKEND_ARG_SPEC = {
+            "exception": Exception,
+        }
+
+        for handler_domain, handler_fn, handler_spec in zip(
+            ["frontend", "backend"],
+            [self.frontend_exception_handler, self.backend_exception_handler],
+            [
+                FRONTEND_ARG_SPEC,
+                BACKEND_ARG_SPEC,
+            ],
+        ):
+            if hasattr(handler_fn, "__name__"):
+                _fn_name = handler_fn.__name__
+            else:
+                _fn_name = handler_fn.__class__.__name__
+
+            if isinstance(handler_fn, functools.partial):
+                raise ValueError(
+                    f"Provided custom {handler_domain} exception handler `{_fn_name}` is a partial function. Please provide a named function instead."
+                )
+
+            if not callable(handler_fn):
+                raise ValueError(
+                    f"Provided custom {handler_domain} exception handler `{_fn_name}` is not a function."
+                )
+
+            # Allow named functions only as lambda functions cannot be introspected
+            if _fn_name == "<lambda>":
+                raise ValueError(
+                    f"Provided custom {handler_domain} exception handler `{_fn_name}` is a lambda function. Please use a named function instead."
+                )
+
+            # Check if the function has the necessary annotations and types in the right order
+            argspec = inspect.getfullargspec(handler_fn)
+            arg_annotations = {
+                k: eval(v) if isinstance(v, str) else v
+                for k, v in argspec.annotations.items()
+                if k not in ["args", "kwargs", "return"]
+            }
+
+            for required_arg_index, required_arg in enumerate(handler_spec):
+                if required_arg not in arg_annotations:
+                    raise ValueError(
+                        f"Provided custom {handler_domain} exception handler `{_fn_name}` does not take the required argument `{required_arg}`"
+                    )
+                elif (
+                    not list(arg_annotations.keys())[required_arg_index] == required_arg
+                ):
+                    raise ValueError(
+                        f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong argument order."
+                        f"Expected `{required_arg}` as the {required_arg_index+1} argument but got `{list(arg_annotations.keys())[required_arg_index]}`"
+                    )
+
+                if not issubclass(arg_annotations[required_arg], Exception):
+                    raise ValueError(
+                        f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong type for {required_arg} argument."
+                        f"Expected to be `Exception` but got `{arg_annotations[required_arg]}`"
+                    )
+
+            # Check if the return type is valid for backend exception handler
+            if handler_domain == "backend":
+                sig = inspect.signature(self.backend_exception_handler)
+                return_type = (
+                    eval(sig.return_annotation)
+                    if isinstance(sig.return_annotation, str)
+                    else sig.return_annotation
+                )
+
+                valid = bool(
+                    return_type == EventSpec
+                    or return_type == Optional[EventSpec]
+                    or return_type == List[EventSpec]
+                    or return_type == inspect.Signature.empty
+                    or return_type is None
+                )
+
+                if not valid:
+                    raise ValueError(
+                        f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong return type."
+                        f"Expected `Union[EventSpec, List[EventSpec], None]` but got `{return_type}`"
+                    )
+
 
 async def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
@@ -1101,6 +1273,8 @@ async def process(
                     yield update
     except Exception as ex:
         telemetry.send_error(ex, context="backend")
+
+        app.backend_exception_handler(ex)
         raise
 
 

+ 4 - 0
reflex/components/base/__init__.py

@@ -13,6 +13,10 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
         "Fragment",
         "fragment",
     ],
+    "error_boundary": [
+        "ErrorBoundary",
+        "error_boundary",
+    ],
     "head": [
         "head",
         "Head",

+ 2 - 0
reflex/components/base/__init__.pyi

@@ -10,6 +10,8 @@ from .document import DocumentHead as DocumentHead
 from .document import Html as Html
 from .document import Main as Main
 from .document import NextScript as NextScript
+from .error_boundary import ErrorBoundary as ErrorBoundary
+from .error_boundary import error_boundary as error_boundary
 from .fragment import Fragment as Fragment
 from .fragment import fragment as fragment
 from .head import Head as Head

+ 78 - 0
reflex/components/base/error_boundary.py

@@ -0,0 +1,78 @@
+"""A React Error Boundary component that catches unhandled frontend exceptions."""
+
+from __future__ import annotations
+
+from typing import List
+
+from reflex.compiler.compiler import _compile_component
+from reflex.components.component import Component
+from reflex.components.el import div, p
+from reflex.constants import Hooks, Imports
+from reflex.event import EventChain, EventHandler
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
+
+
+class ErrorBoundary(Component):
+    """A React Error Boundary component that catches unhandled frontend exceptions."""
+
+    library = "react-error-boundary"
+    tag = "ErrorBoundary"
+
+    # Fired when the boundary catches an error.
+    on_error: EventHandler[lambda error, info: [error, info]] = Var.create_safe(  # type: ignore
+        "logFrontendError", _var_is_string=False, _var_is_local=False
+    ).to(EventChain)
+
+    # Rendered instead of the children when an error is caught.
+    Fallback_component: Var[Component] = Var.create_safe(
+        "Fallback", _var_is_string=False, _var_is_local=False
+    ).to(Component)
+
+    def add_imports(self) -> dict[str, list[ImportVar]]:
+        """Add imports for the component.
+
+        Returns:
+            The imports to add.
+        """
+        return Imports.EVENTS
+
+    def add_hooks(self) -> List[str | Var]:
+        """Add hooks for the component.
+
+        Returns:
+            The hooks to add.
+        """
+        return [Hooks.EVENTS, Hooks.FRONTEND_ERRORS]
+
+    def add_custom_code(self) -> List[str]:
+        """Add custom Javascript code into the page that contains this component.
+
+        Custom code is inserted at module level, after any imports.
+
+        Returns:
+            The custom code to add.
+        """
+        fallback_container = div(
+            p("Ooops...Unknown Reflex error has occured:"),
+            p(
+                Var.create("error.message", _var_is_local=False, _var_is_string=False),
+                color="red",
+            ),
+            p("Please contact the support."),
+        )
+
+        compiled_fallback = _compile_component(fallback_container)
+
+        return [
+            f"""
+                function Fallback({{ error, resetErrorBoundary }}) {{
+                    return (
+                        {compiled_fallback}
+                    );
+                }}
+            """
+        ]
+
+
+error_boundary = ErrorBoundary.create

+ 98 - 0
reflex/components/base/error_boundary.pyi

@@ -0,0 +1,98 @@
+"""Stub file for reflex/components/base/error_boundary.py"""
+
+# ------------------- DO NOT EDIT ----------------------
+# This file was generated by `reflex/utils/pyi_generator.py`!
+# ------------------------------------------------------
+from typing import Any, Callable, Dict, List, Optional, Union, overload
+
+from reflex.components.component import Component
+from reflex.event import EventHandler, EventSpec
+from reflex.style import Style
+from reflex.utils.imports import ImportVar
+from reflex.vars import BaseVar, Var
+
+class ErrorBoundary(Component):
+    def add_imports(self) -> dict[str, list[ImportVar]]: ...
+    def add_hooks(self) -> List[str | Var]: ...
+    def add_custom_code(self) -> List[str]: ...
+    @overload
+    @classmethod
+    def create(  # type: ignore
+        cls,
+        *children,
+        Fallback_component: Optional[Union[Var[Component], Component]] = None,
+        style: Optional[Style] = None,
+        key: Optional[Any] = None,
+        id: Optional[Any] = None,
+        class_name: Optional[Any] = None,
+        autofocus: Optional[bool] = None,
+        custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
+        on_blur: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_click: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_context_menu: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_double_click: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_error: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_focus: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mount: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_down: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_enter: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_leave: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_move: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_out: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_over: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_up: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_scroll: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_unmount: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        **props,
+    ) -> "ErrorBoundary":
+        """Create the component.
+
+        Args:
+            *children: The children of the component.
+            Fallback_component: Rendered instead of the children when an error is caught.
+            style: The style of the component.
+            key: A unique key for the component.
+            id: The id for the component.
+            class_name: The class name for the component.
+            autofocus: Whether the component should take the focus once the page is loaded
+            custom_attrs: custom attribute
+            **props: The props of the component.
+
+        Returns:
+            The component.
+        """
+        ...
+
+error_boundary = ErrorBoundary.create

+ 10 - 0
reflex/constants/compiler.py

@@ -124,6 +124,16 @@ class Hooks(SimpleNamespace):
                   }
                 })"""
 
+    FRONTEND_ERRORS = """
+    const logFrontendError = (error, info) => {
+        if (process.env.NODE_ENV === "production") {
+            addEvents([Event("frontend_event_exception_state.handle_frontend_exception", {
+                stack: error.stack,
+            })])
+        }
+    }
+    """
+
 
 class MemoizationDisposition(enum.Enum):
     """The conditions under which a component should be memoized."""

+ 57 - 13
reflex/state.py

@@ -8,7 +8,6 @@ import copy
 import functools
 import inspect
 import os
-import traceback
 import uuid
 from abc import ABC, abstractmethod
 from collections import defaultdict
@@ -25,6 +24,8 @@ from typing import (
     Sequence,
     Set,
     Type,
+    Union,
+    cast,
 )
 
 import dill
@@ -47,7 +48,6 @@ from reflex.event import (
     EventHandler,
     EventSpec,
     fix_events,
-    window_alert,
 )
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
@@ -1430,15 +1430,39 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Convert valid EventHandler and EventSpec into Event
         fixed_events = fix_events(self._check_valid(handler, events), token)
 
-        # Get the delta after processing the event.
-        delta = state.get_delta()
-        state._clean()
+        try:
+            # Get the delta after processing the event.
+            delta = state.get_delta()
+            state._clean()
+
+            return StateUpdate(
+                delta=delta,
+                events=fixed_events,
+                final=final if not handler.is_background else True,
+            )
+        except Exception as ex:
+            state._clean()
 
-        return StateUpdate(
-            delta=delta,
-            events=fixed_events,
-            final=final if not handler.is_background else True,
-        )
+            app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+
+            event_specs = app_instance.backend_exception_handler(ex)
+
+            if event_specs is None:
+                return StateUpdate()
+
+            event_specs_correct_type = cast(
+                Union[List[Union[EventSpec, EventHandler]], None],
+                [event_specs] if isinstance(event_specs, EventSpec) else event_specs,
+            )
+            fixed_events = fix_events(
+                event_specs_correct_type,
+                token,
+                router_data=state.router_data,
+            )
+            return StateUpdate(
+                events=fixed_events,
+                final=True,
+            )
 
     async def _process_event(
         self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
@@ -1491,12 +1515,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         # If an error occurs, throw a window alert.
         except Exception as ex:
-            error = traceback.format_exc()
-            print(error)
             telemetry.send_error(ex, context="backend")
+
+            app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+
+            event_specs = app_instance.backend_exception_handler(ex)
+
             yield state._as_state_update(
                 handler,
-                window_alert("An error occurred. See logs for details."),
+                event_specs,
                 final=True,
             )
 
@@ -1798,6 +1825,23 @@ class State(BaseState):
     is_hydrated: bool = False
 
 
+class FrontendEventExceptionState(State):
+    """Substate for handling frontend exceptions."""
+
+    def handle_frontend_exception(self, stack: str) -> None:
+        """Handle frontend exceptions.
+
+        If a frontend exception handler is provided, it will be called.
+        Otherwise, the default frontend exception handler will be called.
+
+        Args:
+            stack: The stack trace of the exception.
+
+        """
+        app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+        app_instance.frontend_exception_handler(Exception(stack))
+
+
 class UpdateVarsInternalState(State):
     """Substate for handling internal state var updates."""
 

+ 164 - 0
tests/test_app.py

@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+import functools
 import io
 import json
 import os.path
 import re
 import unittest.mock
 import uuid
+from contextlib import nullcontext as does_not_raise
 from pathlib import Path
 from typing import Generator, List, Tuple, Type
 from unittest.mock import AsyncMock
@@ -1571,3 +1573,165 @@ def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
     app.state = InvalidDepState
     with pytest.raises(exceptions.VarDependencyError):
         app._compile()
+
+
+# Test custom exception handlers
+
+
+def valid_custom_handler(exception: Exception, logger: str = "test"):
+    print("Custom Backend Exception")
+    print(exception)
+
+
+def custom_exception_handler_with_wrong_arg_order(
+    logger: str,
+    exception: Exception,  # Should be first
+):
+    print("Custom Backend Exception")
+    print(exception)
+
+
+def custom_exception_handler_with_wrong_argspec(
+    exception: str,  # Should be Exception
+):
+    print("Custom Backend Exception")
+    print(exception)
+
+
+class DummyExceptionHandler:
+    """Dummy exception handler class."""
+
+    def handle(self, exception: Exception):
+        """Handle the exception.
+
+        Args:
+            exception: The exception.
+
+        """
+        print("Custom Backend Exception")
+        print(exception)
+
+
+custom_exception_handlers = {
+    "lambda": lambda exception: print("Custom Exception Handler", exception),
+    "wrong_argspec": custom_exception_handler_with_wrong_argspec,
+    "wrong_arg_order": custom_exception_handler_with_wrong_arg_order,
+    "valid": valid_custom_handler,
+    "partial": functools.partial(valid_custom_handler, logger="test"),
+    "method": DummyExceptionHandler().handle,
+}
+
+
+@pytest.mark.parametrize(
+    "handler_fn, expected",
+    [
+        pytest.param(
+            custom_exception_handlers["partial"],
+            pytest.raises(ValueError),
+            id="partial",
+        ),
+        pytest.param(
+            custom_exception_handlers["lambda"],
+            pytest.raises(ValueError),
+            id="lambda",
+        ),
+        pytest.param(
+            custom_exception_handlers["wrong_argspec"],
+            pytest.raises(ValueError),
+            id="wrong_argspec",
+        ),
+        pytest.param(
+            custom_exception_handlers["wrong_arg_order"],
+            pytest.raises(ValueError),
+            id="wrong_arg_order",
+        ),
+        pytest.param(
+            custom_exception_handlers["valid"],
+            does_not_raise(),
+            id="valid_handler",
+        ),
+        pytest.param(
+            custom_exception_handlers["method"],
+            does_not_raise(),
+            id="valid_class_method",
+        ),
+    ],
+)
+def test_frontend_exception_handler_validation(handler_fn, expected):
+    """Test that the custom frontend exception handler is properly validated.
+
+    Args:
+        handler_fn: The handler function.
+        expected: The expected result.
+
+    """
+    with expected:
+        rx.App(frontend_exception_handler=handler_fn)._validate_exception_handlers()
+
+
+def backend_exception_handler_with_wrong_return_type(exception: Exception) -> int:
+    """Custom backend exception handler with wrong return type.
+
+    Args:
+        exception: The exception.
+
+    Returns:
+        int: The wrong return type.
+
+    """
+    print("Custom Backend Exception")
+    print(exception)
+
+    return 5
+
+
+@pytest.mark.parametrize(
+    "handler_fn, expected",
+    [
+        pytest.param(
+            backend_exception_handler_with_wrong_return_type,
+            pytest.raises(ValueError),
+            id="wrong_return_type",
+        ),
+        pytest.param(
+            custom_exception_handlers["partial"],
+            pytest.raises(ValueError),
+            id="partial",
+        ),
+        pytest.param(
+            custom_exception_handlers["lambda"],
+            pytest.raises(ValueError),
+            id="lambda",
+        ),
+        pytest.param(
+            custom_exception_handlers["wrong_argspec"],
+            pytest.raises(ValueError),
+            id="wrong_argspec",
+        ),
+        pytest.param(
+            custom_exception_handlers["wrong_arg_order"],
+            pytest.raises(ValueError),
+            id="wrong_arg_order",
+        ),
+        pytest.param(
+            custom_exception_handlers["valid"],
+            does_not_raise(),
+            id="valid_handler",
+        ),
+        pytest.param(
+            custom_exception_handlers["method"],
+            does_not_raise(),
+            id="valid_class_method",
+        ),
+    ],
+)
+def test_backend_exception_handler_validation(handler_fn, expected):
+    """Test that the custom backend exception handler is properly validated.
+
+    Args:
+        handler_fn: The handler function.
+        expected: The expected result.
+
+    """
+    with expected:
+        rx.App(backend_exception_handler=handler_fn)._validate_exception_handlers()

+ 3 - 1
tests/test_state.py

@@ -1464,11 +1464,13 @@ def test_error_on_state_method_shadow():
 
 
 @pytest.mark.asyncio
-async def test_state_with_invalid_yield(capsys):
+async def test_state_with_invalid_yield(capsys, mock_app):
     """Test that an error is thrown when a state yields an invalid value.
 
     Args:
         capsys: Pytest fixture for capture standard streams.
+        mock_app: Mock app fixture.
+
     """
 
     class StateWithInvalidYield(BaseState):