Bläddra i källkod

Catch unhandled errors on both frontend and backend (#3572)

Maxim Vlah 10 månader sedan
förälder
incheckning
772a3ef893

+ 152 - 0
integration/test_exception_handlers.py

@@ -0,0 +1,152 @@
+"""Integration tests for event exception handlers."""
+
+from __future__ import annotations
+
+import time
+from typing import Generator, Type
+
+import pytest
+from selenium.webdriver.common.by import By
+from selenium.webdriver.remote.webdriver import WebDriver
+from selenium.webdriver.support import expected_conditions as EC
+from selenium.webdriver.support.ui import WebDriverWait
+
+from reflex.testing import AppHarness
+
+
+def TestApp():
+    """A test app for event exception handler integration."""
+    import reflex as rx
+
+    class TestAppConfig(rx.Config):
+        """Config for the TestApp app."""
+
+    class TestAppState(rx.State):
+        """State for the TestApp app."""
+
+        def divide_by_number(self, number: int):
+            """Divide by number and print the result.
+
+            Args:
+                number: number to divide by
+
+            """
+            print(1 / number)
+
+    app = rx.App(state=rx.State)
+
+    @app.add_page
+    def index():
+        return rx.vstack(
+            rx.button(
+                "induce_frontend_error",
+                on_click=rx.call_script("induce_frontend_error()"),
+                id="induce-frontend-error-btn",
+            ),
+            rx.button(
+                "induce_backend_error",
+                on_click=lambda: TestAppState.divide_by_number(0),  # type: ignore
+                id="induce-backend-error-btn",
+            ),
+        )
+
+
+@pytest.fixture(scope="module")
+def test_app(
+    app_harness_env: Type[AppHarness], tmp_path_factory
+) -> Generator[AppHarness, None, None]:
+    """Start TestApp app at tmp_path via AppHarness.
+
+    Args:
+        app_harness_env: either AppHarness (dev) or AppHarnessProd (prod)
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+
+    """
+    with app_harness_env.create(
+        root=tmp_path_factory.mktemp("test_app"),
+        app_name=f"testapp_{app_harness_env.__name__.lower()}",
+        app_source=TestApp,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.fixture
+def driver(test_app: AppHarness) -> Generator[WebDriver, None, None]:
+    """Get an instance of the browser open to the test_app app.
+
+    Args:
+        test_app: harness for TestApp app
+
+    Yields:
+        WebDriver instance.
+
+    """
+    assert test_app.app_instance is not None, "app is not running"
+    driver = test_app.frontend()
+    try:
+        yield driver
+    finally:
+        driver.quit()
+
+
+def test_frontend_exception_handler_during_runtime(
+    driver: WebDriver,
+    capsys,
+):
+    """Test calling frontend exception handler during runtime.
+
+    We send an event containing a call to a non-existent function in the frontend.
+    This should trigger the default frontend exception handler.
+
+    Args:
+        driver: WebDriver instance.
+        capsys: pytest fixture for capturing stdout and stderr.
+
+    """
+    reset_button = WebDriverWait(driver, 20).until(
+        EC.element_to_be_clickable((By.ID, "induce-frontend-error-btn"))
+    )
+
+    reset_button.click()
+
+    # Wait for the error to be logged
+    time.sleep(2)
+
+    captured_default_handler_output = capsys.readouterr()
+    assert (
+        "induce_frontend_error" in captured_default_handler_output.out
+        and "ReferenceError" in captured_default_handler_output.out
+    )
+
+
+def test_backend_exception_handler_during_runtime(
+    driver: WebDriver,
+    capsys,
+):
+    """Test calling backend exception handler during runtime.
+
+    We invoke TestAppState.divide_by_zero to induce backend error.
+    This should trigger the default backend exception handler.
+
+    Args:
+        driver: WebDriver instance.
+        capsys: pytest fixture for capturing stdout and stderr.
+
+    """
+    reset_button = WebDriverWait(driver, 20).until(
+        EC.element_to_be_clickable((By.ID, "induce-backend-error-btn"))
+    )
+
+    reset_button.click()
+
+    # Wait for the error to be logged
+    time.sleep(2)
+
+    captured_default_handler_output = capsys.readouterr()
+    assert (
+        "divide_by_number" in captured_default_handler_output.out
+        and "ZeroDivisionError" in captured_default_handler_output.out
+    )

+ 28 - 0
reflex/.templates/web/utils/state.js

@@ -247,6 +247,9 @@ export const applyEvent = async (event, socket) => {
       }
       }
     } catch (e) {
     } catch (e) {
       console.log("_call_script", e);
       console.log("_call_script", e);
+      if (window && window?.onerror) {
+        window.onerror(e.message, null, null, null, e)
+      }
     }
     }
     return false;
     return false;
   }
   }
@@ -687,6 +690,31 @@ export const useEventLoop = (
     }
     }
   }, [router.isReady]);
   }, [router.isReady]);
 
 
+    // Handle frontend errors and send them to the backend via websocket.
+    useEffect(() => {
+      
+      if (typeof window === 'undefined') {
+        return;
+      }
+  
+      window.onerror = function (msg, url, lineNo, columnNo, error) {
+        addEvents([Event("state.frontend_event_exception_state.handle_frontend_exception", {
+          stack: error.stack,
+        })])
+        return false;
+      }
+
+      //NOTE: Only works in Chrome v49+
+      //https://github.com/mknichel/javascript-errors?tab=readme-ov-file#promise-rejection-events
+      window.onunhandledrejection = function (event) {
+          addEvents([Event("state.frontend_event_exception_state.handle_frontend_exception", {
+            stack: event.reason.stack,
+          })])
+          return false;
+      }
+  
+    },[])
+
   // Main event loop.
   // Main event loop.
   useEffect(() => {
   useEffect(() => {
     // Skip if the router is not ready.
     // Skip if the router is not ready.

+ 175 - 1
reflex/app.py

@@ -7,11 +7,13 @@ import concurrent.futures
 import contextlib
 import contextlib
 import copy
 import copy
 import functools
 import functools
+import inspect
 import io
 import io
 import multiprocessing
 import multiprocessing
 import os
 import os
 import platform
 import platform
 import sys
 import sys
+import traceback
 from datetime import datetime
 from datetime import datetime
 from typing import (
 from typing import (
     Any,
     Any,
@@ -45,6 +47,7 @@ from reflex.compiler import compiler
 from reflex.compiler import utils as compiler_utils
 from reflex.compiler import utils as compiler_utils
 from reflex.compiler.compiler import ExecutorSafeFunctions
 from reflex.compiler.compiler import ExecutorSafeFunctions
 from reflex.components.base.app_wrap import AppWrap
 from reflex.components.base.app_wrap import AppWrap
+from reflex.components.base.error_boundary import ErrorBoundary
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.fragment import Fragment
 from reflex.components.component import (
 from reflex.components.component import (
     Component,
     Component,
@@ -60,7 +63,7 @@ from reflex.components.core.client_side_routing import (
 from reflex.components.core.upload import Upload, get_upload_dir
 from reflex.components.core.upload import Upload, get_upload_dir
 from reflex.components.radix import themes
 from reflex.components.radix import themes
 from reflex.config import get_config
 from reflex.config import get_config
-from reflex.event import Event, EventHandler, EventSpec
+from reflex.event import Event, EventHandler, EventSpec, window_alert
 from reflex.model import Model
 from reflex.model import Model
 from reflex.page import (
 from reflex.page import (
     DECORATED_PAGES,
     DECORATED_PAGES,
@@ -88,6 +91,33 @@ ComponentCallable = Callable[[], Component]
 Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
 Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
 
 
 
 
+def default_frontend_exception_handler(exception: Exception) -> None:
+    """Default frontend exception handler function.
+
+    Args:
+        exception: The exception.
+
+    """
+    console.error(f"[Reflex Frontend Exception]\n {exception}\n")
+
+
+def default_backend_exception_handler(exception: Exception) -> EventSpec:
+    """Default backend exception handler function.
+
+    Args:
+        exception: The exception.
+
+    Returns:
+        EventSpec: The window alert event.
+
+    """
+    error = traceback.format_exc()
+
+    console.error(f"[Reflex Backend Exception]\n {error}\n")
+
+    return window_alert("An error occurred. See logs for details.")
+
+
 def default_overlay_component() -> Component:
 def default_overlay_component() -> Component:
     """Default overlay_component attribute for App.
     """Default overlay_component attribute for App.
 
 
@@ -101,6 +131,16 @@ def default_overlay_component() -> Component:
     )
     )
 
 
 
 
+def default_error_boundary() -> Component:
+    """Default error_boundary attribute for App.
+
+    Returns:
+        The default error_boundary, which is an ErrorBoundary.
+
+    """
+    return ErrorBoundary.create()
+
+
 class OverlayFragment(Fragment):
 class OverlayFragment(Fragment):
     """Alias for Fragment, used to wrap the overlay_component."""
     """Alias for Fragment, used to wrap the overlay_component."""
 
 
@@ -142,6 +182,11 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         default_overlay_component
         default_overlay_component
     )
     )
 
 
+    # Error boundary component to wrap the app with.
+    error_boundary: Optional[Union[Component, ComponentCallable]] = (
+        default_error_boundary
+    )
+
     # Components to add to the head of every page.
     # Components to add to the head of every page.
     head_components: List[Component] = []
     head_components: List[Component] = []
 
 
@@ -178,6 +223,16 @@ class App(MiddlewareMixin, LifespanMixin, Base):
     # Background tasks that are currently running. PRIVATE.
     # Background tasks that are currently running. PRIVATE.
     background_tasks: Set[asyncio.Task] = set()
     background_tasks: Set[asyncio.Task] = set()
 
 
+    # Frontend Error Handler Function
+    frontend_exception_handler: Callable[[Exception], None] = (
+        default_frontend_exception_handler
+    )
+
+    # Backend Error Handler Function
+    backend_exception_handler: Callable[
+        [Exception], Union[EventSpec, List[EventSpec], None]
+    ] = default_backend_exception_handler
+
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):
         """Initialize the app.
         """Initialize the app.
 
 
@@ -279,6 +334,9 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         # Mount the socket app with the API.
         # Mount the socket app with the API.
         self.api.mount(str(constants.Endpoint.EVENT), socket_app)
         self.api.mount(str(constants.Endpoint.EVENT), socket_app)
 
 
+        # Check the exception handlers
+        self._validate_exception_handlers()
+
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         """Get the string representation of the app.
         """Get the string representation of the app.
 
 
@@ -688,6 +746,25 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         for k, component in self.pages.items():
         for k, component in self.pages.items():
             self.pages[k] = self._add_overlay_to_component(component)
             self.pages[k] = self._add_overlay_to_component(component)
 
 
+    def _add_error_boundary_to_component(self, component: Component) -> Component:
+        if self.error_boundary is None:
+            return component
+
+        component = ErrorBoundary.create(*component.children)
+
+        return component
+
+    def _setup_error_boundary(self):
+        """If a State is not used and no error_boundary is specified, do not render the error boundary."""
+        if self.state is None and self.error_boundary is default_error_boundary:
+            self.error_boundary = None
+
+        for k, component in self.pages.items():
+            # Skip the 404 page
+            if k == constants.Page404.SLUG:
+                continue
+            self.pages[k] = self._add_error_boundary_to_component(component)
+
     def _apply_decorated_pages(self):
     def _apply_decorated_pages(self):
         """Add @rx.page decorated pages to the app.
         """Add @rx.page decorated pages to the app.
 
 
@@ -757,6 +834,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
 
 
         self._validate_var_dependencies()
         self._validate_var_dependencies()
         self._setup_overlay_component()
         self._setup_overlay_component()
+        self._setup_error_boundary()
 
 
         # Create a progress bar.
         # Create a progress bar.
         progress = Progress(
         progress = Progress(
@@ -1036,6 +1114,100 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         task.add_done_callback(self.background_tasks.discard)
         task.add_done_callback(self.background_tasks.discard)
         return task
         return task
 
 
+    def _validate_exception_handlers(self):
+        """Validate the custom event exception handlers for front- and backend.
+
+        Raises:
+            ValueError: If the custom exception handlers are invalid.
+
+        """
+        FRONTEND_ARG_SPEC = {
+            "exception": Exception,
+        }
+
+        BACKEND_ARG_SPEC = {
+            "exception": Exception,
+        }
+
+        for handler_domain, handler_fn, handler_spec in zip(
+            ["frontend", "backend"],
+            [self.frontend_exception_handler, self.backend_exception_handler],
+            [
+                FRONTEND_ARG_SPEC,
+                BACKEND_ARG_SPEC,
+            ],
+        ):
+            if hasattr(handler_fn, "__name__"):
+                _fn_name = handler_fn.__name__
+            else:
+                _fn_name = handler_fn.__class__.__name__
+
+            if isinstance(handler_fn, functools.partial):
+                raise ValueError(
+                    f"Provided custom {handler_domain} exception handler `{_fn_name}` is a partial function. Please provide a named function instead."
+                )
+
+            if not callable(handler_fn):
+                raise ValueError(
+                    f"Provided custom {handler_domain} exception handler `{_fn_name}` is not a function."
+                )
+
+            # Allow named functions only as lambda functions cannot be introspected
+            if _fn_name == "<lambda>":
+                raise ValueError(
+                    f"Provided custom {handler_domain} exception handler `{_fn_name}` is a lambda function. Please use a named function instead."
+                )
+
+            # Check if the function has the necessary annotations and types in the right order
+            argspec = inspect.getfullargspec(handler_fn)
+            arg_annotations = {
+                k: eval(v) if isinstance(v, str) else v
+                for k, v in argspec.annotations.items()
+                if k not in ["args", "kwargs", "return"]
+            }
+
+            for required_arg_index, required_arg in enumerate(handler_spec):
+                if required_arg not in arg_annotations:
+                    raise ValueError(
+                        f"Provided custom {handler_domain} exception handler `{_fn_name}` does not take the required argument `{required_arg}`"
+                    )
+                elif (
+                    not list(arg_annotations.keys())[required_arg_index] == required_arg
+                ):
+                    raise ValueError(
+                        f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong argument order."
+                        f"Expected `{required_arg}` as the {required_arg_index+1} argument but got `{list(arg_annotations.keys())[required_arg_index]}`"
+                    )
+
+                if not issubclass(arg_annotations[required_arg], Exception):
+                    raise ValueError(
+                        f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong type for {required_arg} argument."
+                        f"Expected to be `Exception` but got `{arg_annotations[required_arg]}`"
+                    )
+
+            # Check if the return type is valid for backend exception handler
+            if handler_domain == "backend":
+                sig = inspect.signature(self.backend_exception_handler)
+                return_type = (
+                    eval(sig.return_annotation)
+                    if isinstance(sig.return_annotation, str)
+                    else sig.return_annotation
+                )
+
+                valid = bool(
+                    return_type == EventSpec
+                    or return_type == Optional[EventSpec]
+                    or return_type == List[EventSpec]
+                    or return_type == inspect.Signature.empty
+                    or return_type is None
+                )
+
+                if not valid:
+                    raise ValueError(
+                        f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong return type."
+                        f"Expected `Union[EventSpec, List[EventSpec], None]` but got `{return_type}`"
+                    )
+
 
 
 async def process(
 async def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
@@ -1101,6 +1273,8 @@ async def process(
                     yield update
                     yield update
     except Exception as ex:
     except Exception as ex:
         telemetry.send_error(ex, context="backend")
         telemetry.send_error(ex, context="backend")
+
+        app.backend_exception_handler(ex)
         raise
         raise
 
 
 
 

+ 4 - 0
reflex/components/base/__init__.py

@@ -13,6 +13,10 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
         "Fragment",
         "Fragment",
         "fragment",
         "fragment",
     ],
     ],
+    "error_boundary": [
+        "ErrorBoundary",
+        "error_boundary",
+    ],
     "head": [
     "head": [
         "head",
         "head",
         "Head",
         "Head",

+ 2 - 0
reflex/components/base/__init__.pyi

@@ -10,6 +10,8 @@ from .document import DocumentHead as DocumentHead
 from .document import Html as Html
 from .document import Html as Html
 from .document import Main as Main
 from .document import Main as Main
 from .document import NextScript as NextScript
 from .document import NextScript as NextScript
+from .error_boundary import ErrorBoundary as ErrorBoundary
+from .error_boundary import error_boundary as error_boundary
 from .fragment import Fragment as Fragment
 from .fragment import Fragment as Fragment
 from .fragment import fragment as fragment
 from .fragment import fragment as fragment
 from .head import Head as Head
 from .head import Head as Head

+ 78 - 0
reflex/components/base/error_boundary.py

@@ -0,0 +1,78 @@
+"""A React Error Boundary component that catches unhandled frontend exceptions."""
+
+from __future__ import annotations
+
+from typing import List
+
+from reflex.compiler.compiler import _compile_component
+from reflex.components.component import Component
+from reflex.components.el import div, p
+from reflex.constants import Hooks, Imports
+from reflex.event import EventChain, EventHandler
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
+
+
+class ErrorBoundary(Component):
+    """A React Error Boundary component that catches unhandled frontend exceptions."""
+
+    library = "react-error-boundary"
+    tag = "ErrorBoundary"
+
+    # Fired when the boundary catches an error.
+    on_error: EventHandler[lambda error, info: [error, info]] = Var.create_safe(  # type: ignore
+        "logFrontendError", _var_is_string=False, _var_is_local=False
+    ).to(EventChain)
+
+    # Rendered instead of the children when an error is caught.
+    Fallback_component: Var[Component] = Var.create_safe(
+        "Fallback", _var_is_string=False, _var_is_local=False
+    ).to(Component)
+
+    def add_imports(self) -> dict[str, list[ImportVar]]:
+        """Add imports for the component.
+
+        Returns:
+            The imports to add.
+        """
+        return Imports.EVENTS
+
+    def add_hooks(self) -> List[str | Var]:
+        """Add hooks for the component.
+
+        Returns:
+            The hooks to add.
+        """
+        return [Hooks.EVENTS, Hooks.FRONTEND_ERRORS]
+
+    def add_custom_code(self) -> List[str]:
+        """Add custom Javascript code into the page that contains this component.
+
+        Custom code is inserted at module level, after any imports.
+
+        Returns:
+            The custom code to add.
+        """
+        fallback_container = div(
+            p("Ooops...Unknown Reflex error has occured:"),
+            p(
+                Var.create("error.message", _var_is_local=False, _var_is_string=False),
+                color="red",
+            ),
+            p("Please contact the support."),
+        )
+
+        compiled_fallback = _compile_component(fallback_container)
+
+        return [
+            f"""
+                function Fallback({{ error, resetErrorBoundary }}) {{
+                    return (
+                        {compiled_fallback}
+                    );
+                }}
+            """
+        ]
+
+
+error_boundary = ErrorBoundary.create

+ 98 - 0
reflex/components/base/error_boundary.pyi

@@ -0,0 +1,98 @@
+"""Stub file for reflex/components/base/error_boundary.py"""
+
+# ------------------- DO NOT EDIT ----------------------
+# This file was generated by `reflex/utils/pyi_generator.py`!
+# ------------------------------------------------------
+from typing import Any, Callable, Dict, List, Optional, Union, overload
+
+from reflex.components.component import Component
+from reflex.event import EventHandler, EventSpec
+from reflex.style import Style
+from reflex.utils.imports import ImportVar
+from reflex.vars import BaseVar, Var
+
+class ErrorBoundary(Component):
+    def add_imports(self) -> dict[str, list[ImportVar]]: ...
+    def add_hooks(self) -> List[str | Var]: ...
+    def add_custom_code(self) -> List[str]: ...
+    @overload
+    @classmethod
+    def create(  # type: ignore
+        cls,
+        *children,
+        Fallback_component: Optional[Union[Var[Component], Component]] = None,
+        style: Optional[Style] = None,
+        key: Optional[Any] = None,
+        id: Optional[Any] = None,
+        class_name: Optional[Any] = None,
+        autofocus: Optional[bool] = None,
+        custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
+        on_blur: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_click: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_context_menu: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_double_click: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_error: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_focus: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mount: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_down: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_enter: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_leave: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_move: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_out: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_over: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_mouse_up: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_scroll: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        on_unmount: Optional[
+            Union[EventHandler, EventSpec, list, Callable, BaseVar]
+        ] = None,
+        **props,
+    ) -> "ErrorBoundary":
+        """Create the component.
+
+        Args:
+            *children: The children of the component.
+            Fallback_component: Rendered instead of the children when an error is caught.
+            style: The style of the component.
+            key: A unique key for the component.
+            id: The id for the component.
+            class_name: The class name for the component.
+            autofocus: Whether the component should take the focus once the page is loaded
+            custom_attrs: custom attribute
+            **props: The props of the component.
+
+        Returns:
+            The component.
+        """
+        ...
+
+error_boundary = ErrorBoundary.create

+ 10 - 0
reflex/constants/compiler.py

@@ -124,6 +124,16 @@ class Hooks(SimpleNamespace):
                   }
                   }
                 })"""
                 })"""
 
 
+    FRONTEND_ERRORS = """
+    const logFrontendError = (error, info) => {
+        if (process.env.NODE_ENV === "production") {
+            addEvents([Event("frontend_event_exception_state.handle_frontend_exception", {
+                stack: error.stack,
+            })])
+        }
+    }
+    """
+
 
 
 class MemoizationDisposition(enum.Enum):
 class MemoizationDisposition(enum.Enum):
     """The conditions under which a component should be memoized."""
     """The conditions under which a component should be memoized."""

+ 57 - 13
reflex/state.py

@@ -8,7 +8,6 @@ import copy
 import functools
 import functools
 import inspect
 import inspect
 import os
 import os
-import traceback
 import uuid
 import uuid
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from collections import defaultdict
@@ -25,6 +24,8 @@ from typing import (
     Sequence,
     Sequence,
     Set,
     Set,
     Type,
     Type,
+    Union,
+    cast,
 )
 )
 
 
 import dill
 import dill
@@ -47,7 +48,6 @@ from reflex.event import (
     EventHandler,
     EventHandler,
     EventSpec,
     EventSpec,
     fix_events,
     fix_events,
-    window_alert,
 )
 )
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
 from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
@@ -1430,15 +1430,39 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Convert valid EventHandler and EventSpec into Event
         # Convert valid EventHandler and EventSpec into Event
         fixed_events = fix_events(self._check_valid(handler, events), token)
         fixed_events = fix_events(self._check_valid(handler, events), token)
 
 
-        # Get the delta after processing the event.
-        delta = state.get_delta()
-        state._clean()
+        try:
+            # Get the delta after processing the event.
+            delta = state.get_delta()
+            state._clean()
+
+            return StateUpdate(
+                delta=delta,
+                events=fixed_events,
+                final=final if not handler.is_background else True,
+            )
+        except Exception as ex:
+            state._clean()
 
 
-        return StateUpdate(
-            delta=delta,
-            events=fixed_events,
-            final=final if not handler.is_background else True,
-        )
+            app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+
+            event_specs = app_instance.backend_exception_handler(ex)
+
+            if event_specs is None:
+                return StateUpdate()
+
+            event_specs_correct_type = cast(
+                Union[List[Union[EventSpec, EventHandler]], None],
+                [event_specs] if isinstance(event_specs, EventSpec) else event_specs,
+            )
+            fixed_events = fix_events(
+                event_specs_correct_type,
+                token,
+                router_data=state.router_data,
+            )
+            return StateUpdate(
+                events=fixed_events,
+                final=True,
+            )
 
 
     async def _process_event(
     async def _process_event(
         self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
         self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
@@ -1491,12 +1515,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         # If an error occurs, throw a window alert.
         # If an error occurs, throw a window alert.
         except Exception as ex:
         except Exception as ex:
-            error = traceback.format_exc()
-            print(error)
             telemetry.send_error(ex, context="backend")
             telemetry.send_error(ex, context="backend")
+
+            app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+
+            event_specs = app_instance.backend_exception_handler(ex)
+
             yield state._as_state_update(
             yield state._as_state_update(
                 handler,
                 handler,
-                window_alert("An error occurred. See logs for details."),
+                event_specs,
                 final=True,
                 final=True,
             )
             )
 
 
@@ -1798,6 +1825,23 @@ class State(BaseState):
     is_hydrated: bool = False
     is_hydrated: bool = False
 
 
 
 
+class FrontendEventExceptionState(State):
+    """Substate for handling frontend exceptions."""
+
+    def handle_frontend_exception(self, stack: str) -> None:
+        """Handle frontend exceptions.
+
+        If a frontend exception handler is provided, it will be called.
+        Otherwise, the default frontend exception handler will be called.
+
+        Args:
+            stack: The stack trace of the exception.
+
+        """
+        app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+        app_instance.frontend_exception_handler(Exception(stack))
+
+
 class UpdateVarsInternalState(State):
 class UpdateVarsInternalState(State):
     """Substate for handling internal state var updates."""
     """Substate for handling internal state var updates."""
 
 

+ 164 - 0
tests/test_app.py

@@ -1,11 +1,13 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
+import functools
 import io
 import io
 import json
 import json
 import os.path
 import os.path
 import re
 import re
 import unittest.mock
 import unittest.mock
 import uuid
 import uuid
+from contextlib import nullcontext as does_not_raise
 from pathlib import Path
 from pathlib import Path
 from typing import Generator, List, Tuple, Type
 from typing import Generator, List, Tuple, Type
 from unittest.mock import AsyncMock
 from unittest.mock import AsyncMock
@@ -1571,3 +1573,165 @@ def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
     app.state = InvalidDepState
     app.state = InvalidDepState
     with pytest.raises(exceptions.VarDependencyError):
     with pytest.raises(exceptions.VarDependencyError):
         app._compile()
         app._compile()
+
+
+# Test custom exception handlers
+
+
+def valid_custom_handler(exception: Exception, logger: str = "test"):
+    print("Custom Backend Exception")
+    print(exception)
+
+
+def custom_exception_handler_with_wrong_arg_order(
+    logger: str,
+    exception: Exception,  # Should be first
+):
+    print("Custom Backend Exception")
+    print(exception)
+
+
+def custom_exception_handler_with_wrong_argspec(
+    exception: str,  # Should be Exception
+):
+    print("Custom Backend Exception")
+    print(exception)
+
+
+class DummyExceptionHandler:
+    """Dummy exception handler class."""
+
+    def handle(self, exception: Exception):
+        """Handle the exception.
+
+        Args:
+            exception: The exception.
+
+        """
+        print("Custom Backend Exception")
+        print(exception)
+
+
+custom_exception_handlers = {
+    "lambda": lambda exception: print("Custom Exception Handler", exception),
+    "wrong_argspec": custom_exception_handler_with_wrong_argspec,
+    "wrong_arg_order": custom_exception_handler_with_wrong_arg_order,
+    "valid": valid_custom_handler,
+    "partial": functools.partial(valid_custom_handler, logger="test"),
+    "method": DummyExceptionHandler().handle,
+}
+
+
+@pytest.mark.parametrize(
+    "handler_fn, expected",
+    [
+        pytest.param(
+            custom_exception_handlers["partial"],
+            pytest.raises(ValueError),
+            id="partial",
+        ),
+        pytest.param(
+            custom_exception_handlers["lambda"],
+            pytest.raises(ValueError),
+            id="lambda",
+        ),
+        pytest.param(
+            custom_exception_handlers["wrong_argspec"],
+            pytest.raises(ValueError),
+            id="wrong_argspec",
+        ),
+        pytest.param(
+            custom_exception_handlers["wrong_arg_order"],
+            pytest.raises(ValueError),
+            id="wrong_arg_order",
+        ),
+        pytest.param(
+            custom_exception_handlers["valid"],
+            does_not_raise(),
+            id="valid_handler",
+        ),
+        pytest.param(
+            custom_exception_handlers["method"],
+            does_not_raise(),
+            id="valid_class_method",
+        ),
+    ],
+)
+def test_frontend_exception_handler_validation(handler_fn, expected):
+    """Test that the custom frontend exception handler is properly validated.
+
+    Args:
+        handler_fn: The handler function.
+        expected: The expected result.
+
+    """
+    with expected:
+        rx.App(frontend_exception_handler=handler_fn)._validate_exception_handlers()
+
+
+def backend_exception_handler_with_wrong_return_type(exception: Exception) -> int:
+    """Custom backend exception handler with wrong return type.
+
+    Args:
+        exception: The exception.
+
+    Returns:
+        int: The wrong return type.
+
+    """
+    print("Custom Backend Exception")
+    print(exception)
+
+    return 5
+
+
+@pytest.mark.parametrize(
+    "handler_fn, expected",
+    [
+        pytest.param(
+            backend_exception_handler_with_wrong_return_type,
+            pytest.raises(ValueError),
+            id="wrong_return_type",
+        ),
+        pytest.param(
+            custom_exception_handlers["partial"],
+            pytest.raises(ValueError),
+            id="partial",
+        ),
+        pytest.param(
+            custom_exception_handlers["lambda"],
+            pytest.raises(ValueError),
+            id="lambda",
+        ),
+        pytest.param(
+            custom_exception_handlers["wrong_argspec"],
+            pytest.raises(ValueError),
+            id="wrong_argspec",
+        ),
+        pytest.param(
+            custom_exception_handlers["wrong_arg_order"],
+            pytest.raises(ValueError),
+            id="wrong_arg_order",
+        ),
+        pytest.param(
+            custom_exception_handlers["valid"],
+            does_not_raise(),
+            id="valid_handler",
+        ),
+        pytest.param(
+            custom_exception_handlers["method"],
+            does_not_raise(),
+            id="valid_class_method",
+        ),
+    ],
+)
+def test_backend_exception_handler_validation(handler_fn, expected):
+    """Test that the custom backend exception handler is properly validated.
+
+    Args:
+        handler_fn: The handler function.
+        expected: The expected result.
+
+    """
+    with expected:
+        rx.App(backend_exception_handler=handler_fn)._validate_exception_handlers()

+ 3 - 1
tests/test_state.py

@@ -1464,11 +1464,13 @@ def test_error_on_state_method_shadow():
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_state_with_invalid_yield(capsys):
+async def test_state_with_invalid_yield(capsys, mock_app):
     """Test that an error is thrown when a state yields an invalid value.
     """Test that an error is thrown when a state yields an invalid value.
 
 
     Args:
     Args:
         capsys: Pytest fixture for capture standard streams.
         capsys: Pytest fixture for capture standard streams.
+        mock_app: Mock app fixture.
+
     """
     """
 
 
     class StateWithInvalidYield(BaseState):
     class StateWithInvalidYield(BaseState):