소스 검색

[REF-201] Separate on_load handler from initial hydration (#1847)

Masen Furer 1 년 전
부모
커밋
60147dec65

+ 5 - 0
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -23,10 +23,15 @@ export const clientStorage = {}
 {% endif %}
 
 {% if state_name %}
+export const onLoadInternalEvent = () => [Event('{{state_name}}.{{const.on_load_internal}}')]
+
 export const initialEvents = () => [
     Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
+    ...onLoadInternalEvent()
 ]
 {% else %}
+export const onLoadInternalEvent = () => []
+
 export const initialEvents = () => []
 {% endif %}
 

+ 9 - 4
reflex/.templates/web/utils/state.js

@@ -6,7 +6,7 @@ import env from "env.json";
 import Cookies from "universal-cookie";
 import { useEffect, useReducer, useRef, useState } from "react";
 import Router, { useRouter } from "next/router";
-import { initialEvents, initialState } from "utils/context.js"
+import { initialEvents, initialState, onLoadInternalEvent } from "utils/context.js"
 
 // Endpoint URLs.
 const EVENTURL = env.EVENT
@@ -529,10 +529,15 @@ export const useEventLoop = (
   }
 
   const sentHydrate = useRef(false);  // Avoid double-hydrate due to React strict-mode
-  // initial state hydrate
   useEffect(() => {
     if (router.isReady && !sentHydrate.current) {
-      addEvents(initial_events())
+    const events = initial_events()
+      addEvents(events.map((e) => (
+        {
+          ...e,
+          router_data: (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(router)
+        }
+      )))
       sentHydrate.current = true
     }
   }, [router.isReady])
@@ -560,7 +565,7 @@ export const useEventLoop = (
 
   // Route after the initial page hydration.
   useEffect(() => {
-    const change_complete = () => addEvents(initial_events())
+    const change_complete = () => addEvents(onLoadInternalEvent())
     router.events.on('routeChangeComplete', change_complete)
     return () => {
       router.events.off('routeChangeComplete', change_complete)

+ 1 - 1
reflex/app.pyi

@@ -125,7 +125,7 @@ class App(Base):
         self, state: State, event: Event
     ) -> asyncio.Task | None: ...
 
-async def process(
+def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
 ) -> AsyncIterator[StateUpdate]: ...
 async def ping() -> str: ...

+ 1 - 0
reflex/compiler/templates.py

@@ -40,6 +40,7 @@ class ReflexJinjaEnvironment(Environment):
             "toggle_color_mode": constants.ColorMode.TOGGLE,
             "use_color_mode": constants.ColorMode.USE,
             "hydrate": constants.CompileVars.HYDRATE,
+            "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
         }
 
 

+ 2 - 0
reflex/constants/__init__.py

@@ -48,6 +48,7 @@ from .route import (
     ROUTE_NOT_FOUND,
     ROUTER,
     ROUTER_DATA,
+    ROUTER_DATA_INCLUDE,
     DefaultPage,
     Page404,
     RouteArgType,
@@ -97,6 +98,7 @@ __ALL__ = [
     RouteVar,
     ROUTER,
     ROUTER_DATA,
+    ROUTER_DATA_INCLUDE,
     ROUTE_NOT_FOUND,
     SETTER_PREFIX,
     SKIP_COMPILE_ENV_VAR,

+ 2 - 0
reflex/constants/compiler.py

@@ -58,6 +58,8 @@ class CompileVars(SimpleNamespace):
     CONNECT_ERROR = "connectError"
     # The name of the function for converting a dict to an event.
     TO_EVENT = "Event"
+    # The name of the internal on_load event.
+    ON_LOAD_INTERNAL = "on_load_internal"
 
 
 class PageNames(SimpleNamespace):

+ 4 - 0
reflex/constants/route.py

@@ -30,6 +30,10 @@ class RouteVar(SimpleNamespace):
     COOKIE = "cookie"
 
 
+# This subset of router_data is included in chained on_load events.
+ROUTER_DATA_INCLUDE = set((RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY))
+
+
 class RouteRegex(SimpleNamespace):
     """Regex used for extracting route args in route."""
 

+ 11 - 1
reflex/event.py

@@ -826,6 +826,10 @@ def fix_events(
     # Fix the events created by the handler.
     out = []
     for e in events:
+        if isinstance(e, Event):
+            # If the event is already an event, append it to the list.
+            out.append(e)
+            continue
         if not isinstance(e, (EventHandler, EventSpec)):
             e = EventHandler(fn=e)
         # Otherwise, create an event from the event spec.
@@ -835,13 +839,19 @@ def fix_events(
         name = format.format_event_handler(e.handler)
         payload = {k._var_name: v._decode() for k, v in e.args}  # type: ignore
 
+        # Filter router_data to reduce payload size
+        event_router_data = {
+            k: v
+            for k, v in (router_data or {}).items()
+            if k in constants.route.ROUTER_DATA_INCLUDE
+        }
         # Create an event and append it to the list.
         out.append(
             Event(
                 token=token,
                 name=name,
                 payload=payload,
-                router_data=router_data or {},
+                router_data=event_router_data,
             )
         )
 

+ 2 - 8
reflex/middleware/hydrate_middleware.py

@@ -4,7 +4,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Optional
 
 from reflex import constants
-from reflex.event import Event, fix_events, get_hydrate_event
+from reflex.event import Event, get_hydrate_event
 from reflex.middleware.middleware import Middleware
 from reflex.state import BaseState, StateUpdate
 from reflex.utils import format
@@ -52,11 +52,5 @@ class HydrateMiddleware(Middleware):
         # since a full dict was captured, clean any dirtiness
         state._clean()
 
-        # Get the route for on_load events.
-        route = event.router_data.get(constants.RouteVar.PATH, "")
-        # Add the on_load events and set is_hydrated to True.
-        events = [*app.get_load_events(route), type(state).set_is_hydrated(True)]  # type: ignore
-        events = fix_events(events, event.token, router_data=event.router_data)
-
         # Return the state update.
-        return StateUpdate(delta=delta, events=events)
+        return StateUpdate(delta=delta, events=[])

+ 21 - 1
reflex/state.py

@@ -1016,7 +1016,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
 
         def _is_valid_type(events: Any) -> bool:
-            return isinstance(events, (EventHandler, EventSpec))
+            return isinstance(events, (Event, EventHandler, EventSpec))
 
         if events is None or _is_valid_type(events):
             return events
@@ -1313,6 +1313,26 @@ class State(BaseState):
     # The hydrated bool.
     is_hydrated: bool = False
 
+    def on_load_internal(self) -> list[Event | EventSpec] | None:
+        """Queue on_load handlers for the current page.
+
+        Returns:
+            The list of events to queue for on load handling.
+        """
+        app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+        load_events = app.get_load_events(self.router.page.path)
+        if not load_events and self.is_hydrated:
+            return  # Fast path for page-to-page navigation
+        self.is_hydrated = False
+        return [
+            *fix_events(
+                load_events,
+                self.router.session.client_token,
+                router_data=self.router_data,
+            ),
+            type(self).set_is_hydrated(True),  # type: ignore
+        ]
+
 
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.

+ 8 - 0
reflex/utils/prerequisites.py

@@ -123,9 +123,17 @@ def get_app(reload: bool = False) -> ModuleType:
 
     Returns:
         The app based on the default config.
+
+    Raises:
+        RuntimeError: If the app name is not set in the config.
     """
     os.environ[constants.RELOAD_CONFIG] = str(reload)
     config = get_config()
+    if not config.app_name:
+        raise RuntimeError(
+            "Cannot get the app module because `app_name` is not set in rxconfig! "
+            "If this error occurs in a reflex test case, ensure that `get_app` is mocked."
+        )
     module = ".".join([config.app_name, config.app_name])
     sys.path.insert(0, os.getcwd())
     app = __import__(module, fromlist=(constants.CompileVars.APP,))

+ 22 - 0
tests/conftest.py

@@ -5,11 +5,13 @@ import platform
 import uuid
 from pathlib import Path
 from typing import Dict, Generator
+from unittest import mock
 
 import pytest
 
 from reflex.app import App
 from reflex.event import EventSpec
+from reflex.utils import prerequisites
 
 from .states import (
     DictMutationTestState,
@@ -30,6 +32,26 @@ def app() -> App:
     return App()
 
 
+@pytest.fixture
+def app_module_mock(monkeypatch) -> mock.Mock:
+    """Mock the app module.
+
+    This overwrites prerequisites.get_app to return the mock for the app module.
+
+    To use this in your test, assign `app_module_mock.app = rx.App(...)`.
+
+    Args:
+        monkeypatch: pytest monkeypatch fixture.
+
+    Returns:
+        The mock for the main app module.
+    """
+    app_module_mock = mock.Mock()
+    get_app_mock = mock.Mock(return_value=app_module_mock)
+    monkeypatch.setattr(prerequisites, "get_app", get_app_mock)
+    return app_module_mock
+
+
 @pytest.fixture(scope="session")
 def windows_platform() -> Generator:
     """Check if system is windows.

+ 1 - 11
tests/middleware/conftest.py

@@ -21,14 +21,4 @@ def create_event(name):
 
 @pytest.fixture
 def event1():
-    return create_event("test_state.hydrate")
-
-
-@pytest.fixture
-def event2():
-    return create_event("test_state2.hydrate")
-
-
-@pytest.fixture
-def event3():
-    return create_event("test_state3.hydrate")
+    return create_event("state.hydrate")

+ 9 - 134
tests/middleware/test_hydrate_middleware.py

@@ -1,27 +1,13 @@
-from typing import Any, Dict
+from __future__ import annotations
 
 import pytest
 
-from reflex import constants
 from reflex.app import App
-from reflex.constants import CompileVars
 from reflex.middleware.hydrate_middleware import HydrateMiddleware
-from reflex.state import BaseState, StateUpdate
+from reflex.state import State, StateUpdate
 
 
-def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
-    """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
-
-    Args:
-        state: the State that is hydrated
-
-    Returns:
-        dict similar to that returned by `State.get_delta` with IS_HYDRATED: True
-    """
-    return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
-
-
-class TestState(BaseState):
+class TestState(State):
     """A test state with no return in handler."""
 
     __test__ = False
@@ -33,40 +19,6 @@ class TestState(BaseState):
         self.num += 1
 
 
-class TestState2(BaseState):
-    """A test state with return in handler."""
-
-    __test__ = False
-
-    num: int = 0
-    name: str
-
-    def test_handler(self):
-        """Test handler that calls another handler.
-
-        Returns:
-            Chain of EventHandlers
-        """
-        self.num += 1
-        return self.change_name
-
-    def change_name(self):
-        """Test handler to change name."""
-        self.name = "random"
-
-
-class TestState3(BaseState):
-    """A test state with async handler."""
-
-    __test__ = False
-
-    num: int = 0
-
-    async def test_handler(self):
-        """Test handler."""
-        self.num += 1
-
-
 @pytest.fixture
 def hydrate_middleware() -> HydrateMiddleware:
     """Fixture creates an instance of HydrateMiddleware per test case.
@@ -78,98 +30,21 @@ def hydrate_middleware() -> HydrateMiddleware:
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "test_state, expected, event_fixture",
-    [
-        (TestState, {"test_state": {"num": 1}}, "event1"),
-        (TestState2, {"test_state2": {"num": 1}}, "event2"),
-        (TestState3, {"test_state3": {"num": 1}}, "event3"),
-    ],
-)
-async def test_preprocess(
-    test_state, hydrate_middleware, request, event_fixture, expected
-):
-    """Test that a state hydrate event is processed correctly.
-
-    Args:
-        test_state: State to process event.
-        hydrate_middleware: Instance of HydrateMiddleware.
-        request: Pytest fixture request.
-        event_fixture: The event fixture(an Event).
-        expected: Expected delta.
-    """
-    test_state.add_var(
-        constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
-    )
-    app = App(state=test_state, load_events={"index": [test_state.test_handler]})
-    state = test_state()
-
-    update = await hydrate_middleware.preprocess(
-        app=app, event=request.getfixturevalue(event_fixture), state=state
-    )
-    assert isinstance(update, StateUpdate)
-    assert update.delta == state.dict()
-    events = update.events
-    assert len(events) == 2
-
-    # Apply the on_load event.
-    update = await state._process(events[0]).__anext__()
-    assert update.delta == expected
-
-    # Apply the hydrate event.
-    update = await state._process(events[1]).__anext__()
-    assert update.delta == exp_is_hydrated(state)
-
-
-@pytest.mark.asyncio
-async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
-    """Test that a state hydrate event for multiple on-load events is processed correctly.
-
-    Args:
-        hydrate_middleware: Instance of HydrateMiddleware
-        event1: An Event.
-    """
-    app = App(
-        state=TestState,
-        load_events={"index": [TestState.test_handler, TestState.test_handler]},
-    )
-    state = TestState()
-
-    update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
-    assert isinstance(update, StateUpdate)
-    assert update.delta == state.dict()
-    assert len(update.events) == 3
-
-    # Apply the events.
-    events = update.events
-    update = await state._process(events[0]).__anext__()
-    assert update.delta == {"test_state": {"num": 1}}
-
-    update = await state._process(events[1]).__anext__()
-    assert update.delta == {"test_state": {"num": 2}}
-
-    update = await state._process(events[2]).__anext__()
-    assert update.delta == exp_is_hydrated(state)
-
-
-@pytest.mark.asyncio
-async def test_preprocess_no_events(hydrate_middleware, event1):
+async def test_preprocess_no_events(hydrate_middleware, event1, mocker):
     """Test that app without on_load is processed correctly.
 
     Args:
         hydrate_middleware: Instance of HydrateMiddleware
         event1: An Event.
+        mocker: pytest mock object.
     """
-    state = TestState()
+    mocker.patch("reflex.state.State.class_subclasses", {TestState})
+    state = State()
     update = await hydrate_middleware.preprocess(
-        app=App(state=TestState),
+        app=App(state=State),
         event=event1,
         state=state,
     )
     assert isinstance(update, StateUpdate)
     assert update.delta == state.dict()
-    assert len(update.events) == 1
-    assert isinstance(update, StateUpdate)
-
-    update = await state._process(update.events[0]).__anext__()
-    assert update.delta == exp_is_hydrated(state)
+    assert not update.events

+ 28 - 33
tests/test_app.py

@@ -25,10 +25,10 @@ from reflex.app import (
     upload,
 )
 from reflex.components import Box, Component, Cond, Fragment, Text
-from reflex.event import Event, get_hydrate_event
+from reflex.event import Event
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
-from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
+from reflex.state import BaseState, State, StateManagerRedis, StateUpdate
 from reflex.style import Style
 from reflex.utils import format
 from reflex.vars import ComputedVar
@@ -870,6 +870,7 @@ class DynamicState(BaseState):
         recalculated when the dynamic route var was dirty
     """
 
+    is_hydrated: bool = False
     loaded: int = 0
     counter: int = 0
 
@@ -893,10 +894,16 @@ class DynamicState(BaseState):
         # self.side_effect_counter = self.side_effect_counter + 1
         return self.dynamic
 
+    on_load_internal = State.on_load_internal.fn
+
 
 @pytest.mark.asyncio
 async def test_dynamic_route_var_route_change_completed_on_load(
-    index_page, windows_platform: bool, token: str, mocker
+    index_page,
+    windows_platform: bool,
+    token: str,
+    app_module_mock: unittest.mock.Mock,
+    mocker,
 ):
     """Create app with dynamic route var, and simulate navigation.
 
@@ -907,17 +914,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         index_page: The index page.
         windows_platform: Whether the system is windows.
         token: a Token.
+        app_module_mock: Mocked app module.
         mocker: pytest mocker object.
     """
-    mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
-    DynamicState.add_var(
-        constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
-    )
     arg_name = "dynamic"
     route = f"/test/[{arg_name}]"
     if windows_platform:
         route.lstrip("/").replace("/", "\\")
-    app = App(state=DynamicState)
+    app = app_module_mock.app = App(state=DynamicState)
     assert arg_name not in app.state.vars
     app.add_page(index_page, route=route, on_load=DynamicState.on_load)  # type: ignore
     assert arg_name in app.state.vars
@@ -953,33 +957,25 @@ async def test_dynamic_route_var_route_change_completed_on_load(
 
     prev_exp_val = ""
     for exp_index, exp_val in enumerate(exp_vals):
-        hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
-        exp_router_data = {
-            "headers": {},
-            "ip": client_ip,
-            "sid": sid,
-            "token": token,
-            **hydrate_event.router_data,
-        }
-        exp_router = RouterData(exp_router_data)
+        on_load_internal = _event(
+            name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}",
+            val=exp_val,
+        )
         process_coro = process(
             app,
-            event=hydrate_event,
+            event=on_load_internal,
             sid=sid,
             headers={},
             client_ip=client_ip,
         )
-        update = await process_coro.__anext__()  # type: ignore
-        # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
+        update = await process_coro.__anext__()
+        # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
             delta={
                 state.get_name(): {
                     arg_name: exp_val,
                     f"comp_{arg_name}": exp_val,
                     constants.CompileVars.IS_HYDRATED: False,
-                    "loaded": exp_index,
-                    "counter": exp_index,
-                    "router": exp_router,
                     # "side_effect_counter": exp_index,
                 }
             },
@@ -987,13 +983,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                 _dynamic_state_event(
                     name="on_load",
                     val=exp_val,
-                    router_data=exp_router_data,
                 ),
                 _dynamic_state_event(
                     name="set_is_hydrated",
                     payload={"value": True},
                     val=exp_val,
-                    router_data=exp_router_data,
+                    router_data={},
                 ),
             ],
         )
@@ -1004,7 +999,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
 
         # complete the processing
         with pytest.raises(StopAsyncIteration):
-            await process_coro.__anext__()  # type: ignore
+            await process_coro.__anext__()
 
         # check that router data was written to the state_manager store
         state = await app.state_manager.get_state(token)
@@ -1017,7 +1012,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             headers={},
             client_ip=client_ip,
         )
-        on_load_update = await process_coro.__anext__()  # type: ignore
+        on_load_update = await process_coro.__anext__()
         assert on_load_update == StateUpdate(
             delta={
                 state.get_name(): {
@@ -1031,7 +1026,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         # complete the processing
         with pytest.raises(StopAsyncIteration):
-            await process_coro.__anext__()  # type: ignore
+            await process_coro.__anext__()
         process_coro = process(
             app,
             event=_dynamic_state_event(
@@ -1041,7 +1036,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             headers={},
             client_ip=client_ip,
         )
-        on_set_is_hydrated_update = await process_coro.__anext__()  # type: ignore
+        on_set_is_hydrated_update = await process_coro.__anext__()
         assert on_set_is_hydrated_update == StateUpdate(
             delta={
                 state.get_name(): {
@@ -1055,7 +1050,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         # complete the processing
         with pytest.raises(StopAsyncIteration):
-            await process_coro.__anext__()  # type: ignore
+            await process_coro.__anext__()
 
         # a simple state update event should NOT trigger on_load or route var side effects
         process_coro = process(
@@ -1065,7 +1060,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             headers={},
             client_ip=client_ip,
         )
-        update = await process_coro.__anext__()  # type: ignore
+        update = await process_coro.__anext__()
         assert update == StateUpdate(
             delta={
                 state.get_name(): {
@@ -1079,7 +1074,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         # complete the processing
         with pytest.raises(StopAsyncIteration):
-            await process_coro.__anext__()  # type: ignore
+            await process_coro.__anext__()
 
         prev_exp_val = exp_val
     state = await app.state_manager.get_state(token)
@@ -1116,7 +1111,7 @@ async def test_process_events(mocker, token: str):
         token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
     )
 
-    async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):  # type: ignore
+    async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
         pass
 
     assert (await app.state_manager.get_state(token)).value == 5

+ 165 - 11
tests/test_state.py

@@ -7,7 +7,7 @@ import functools
 import json
 import os
 import sys
-from typing import Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Union
 from unittest.mock import AsyncMock, Mock
 
 import pytest
@@ -24,6 +24,7 @@ from reflex.state import (
     LockExpiredError,
     MutableProxy,
     RouterData,
+    State,
     StateManager,
     StateManagerMemory,
     StateManagerRedis,
@@ -1374,8 +1375,13 @@ def test_error_on_state_method_shadow():
     )
 
 
-def test_state_with_invalid_yield():
-    """Test that an error is thrown when a state yields an invalid value."""
+@pytest.mark.asyncio
+async def test_state_with_invalid_yield(capsys):
+    """Test that an error is thrown when a state yields an invalid value.
+
+    Args:
+        capsys: Pytest fixture for capture standard streams.
+    """
 
     class StateWithInvalidYield(BaseState):
         """A state that yields an invalid value."""
@@ -1389,15 +1395,16 @@ def test_state_with_invalid_yield():
             yield 1
 
     invalid_state = StateWithInvalidYield()
-    with pytest.raises(TypeError) as err:
-        invalid_state._check_valid(
-            invalid_state.event_handlers["invalid_handler"],
-            rx.event.Event(token="fake_token", name="invalid_handler"),
+    async for update in invalid_state._process(
+        rx.event.Event(token="fake_token", name="invalid_handler")
+    ):
+        assert not update.delta
+        assert update.events == rx.event.fix_events(
+            [rx.window_alert("An error occurred. See logs for details.")],
+            token="",
         )
-    assert (
-        "must only return/yield: None, Events or other EventHandlers"
-        in err.value.args[0]
-    )
+    captured = capsys.readouterr()
+    assert "must only return/yield: None, Events or other EventHandlers" in captured.out
 
 
 @pytest.fixture(scope="function", params=["in_process", "redis"])
@@ -2303,3 +2310,150 @@ def test_state_union_optional():
     assert UnionState.custom_union.c2r is not None  # type: ignore
     assert types.is_optional(UnionState.opt_int._var_type)  # type: ignore
     assert types.is_union(UnionState.int_float._var_type)  # type: ignore
+
+
+def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
+    """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
+
+    Args:
+        state: the State that is hydrated.
+        is_hydrated: whether the state is hydrated.
+
+    Returns:
+        dict similar to that returned by `State.get_delta` with IS_HYDRATED: is_hydrated
+    """
+    return {state.get_full_name(): {CompileVars.IS_HYDRATED: is_hydrated}}
+
+
+class OnLoadState(State):
+    """A test state with no return in handler."""
+
+    num: int = 0
+
+    def test_handler(self):
+        """Test handler."""
+        self.num += 1
+
+
+class OnLoadState2(State):
+    """A test state with return in handler."""
+
+    num: int = 0
+    name: str
+
+    def test_handler(self):
+        """Test handler that calls another handler.
+
+        Returns:
+            Chain of EventHandlers
+        """
+        self.num += 1
+        return self.change_name
+
+    def change_name(self):
+        """Test handler to change name."""
+        self.name = "random"
+
+
+class OnLoadState3(State):
+    """A test state with async handler."""
+
+    num: int = 0
+
+    async def test_handler(self):
+        """Test handler."""
+        self.num += 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "test_state, expected",
+    [
+        (OnLoadState, {"on_load_state": {"num": 1}}),
+        (OnLoadState2, {"on_load_state2": {"num": 1}}),
+        (OnLoadState3, {"on_load_state3": {"num": 1}}),
+    ],
+)
+async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
+    """Test that a state hydrate event is processed correctly.
+
+    Args:
+        app_module_mock: The app module that will be returned by get_app().
+        token: A token.
+        test_state: State to process event.
+        expected: Expected delta.
+        mocker: pytest mock object.
+    """
+    mocker.patch("reflex.state.State.class_subclasses", {test_state})
+    app = app_module_mock.app = App(
+        state=State, load_events={"index": [test_state.test_handler]}
+    )
+    state = State()
+
+    updates = []
+    async for update in rx.app.process(
+        app=app,
+        event=Event(
+            token=token,
+            name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}",
+            router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
+        ),
+        sid="sid",
+        headers={},
+        client_ip="",
+    ):
+        assert isinstance(update, StateUpdate)
+        updates.append(update)
+    assert len(updates) == 1
+    assert updates[0].delta == exp_is_hydrated(state, False)
+
+    events = updates[0].events
+    assert len(events) == 2
+    assert (await state._process(events[0]).__anext__()).delta == {
+        test_state.get_full_name(): {"num": 1}
+    }
+    assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
+
+
+@pytest.mark.asyncio
+async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
+    """Test that a state hydrate event for multiple on-load events is processed correctly.
+
+    Args:
+        app_module_mock: The app module that will be returned by get_app().
+        token: A token.
+        mocker: pytest mock object.
+    """
+    mocker.patch("reflex.state.State.class_subclasses", {OnLoadState})
+    app = app_module_mock.app = App(
+        state=State,
+        load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
+    )
+    state = State()
+
+    updates = []
+    async for update in rx.app.process(
+        app=app,
+        event=Event(
+            token=token,
+            name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}",
+            router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
+        ),
+        sid="sid",
+        headers={},
+        client_ip="",
+    ):
+        assert isinstance(update, StateUpdate)
+        updates.append(update)
+    assert len(updates) == 1
+    assert updates[0].delta == exp_is_hydrated(state, False)
+
+    events = updates[0].events
+    assert len(events) == 3
+    assert (await state._process(events[0]).__anext__()).delta == {
+        OnLoadState.get_full_name(): {"num": 1}
+    }
+    assert (await state._process(events[1]).__anext__()).delta == {
+        OnLoadState.get_full_name(): {"num": 2}
+    }
+    assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)