Bladeren bron

[REF-201] Separate on_load handler from initial hydration (#1847)

Masen Furer 1 jaar geleden
bovenliggende
commit
60147dec65

+ 5 - 0
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -23,10 +23,15 @@ export const clientStorage = {}
 {% endif %}
 {% endif %}
 
 
 {% if state_name %}
 {% if state_name %}
+export const onLoadInternalEvent = () => [Event('{{state_name}}.{{const.on_load_internal}}')]
+
 export const initialEvents = () => [
 export const initialEvents = () => [
     Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
     Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
+    ...onLoadInternalEvent()
 ]
 ]
 {% else %}
 {% else %}
+export const onLoadInternalEvent = () => []
+
 export const initialEvents = () => []
 export const initialEvents = () => []
 {% endif %}
 {% endif %}
 
 

+ 9 - 4
reflex/.templates/web/utils/state.js

@@ -6,7 +6,7 @@ import env from "env.json";
 import Cookies from "universal-cookie";
 import Cookies from "universal-cookie";
 import { useEffect, useReducer, useRef, useState } from "react";
 import { useEffect, useReducer, useRef, useState } from "react";
 import Router, { useRouter } from "next/router";
 import Router, { useRouter } from "next/router";
-import { initialEvents, initialState } from "utils/context.js"
+import { initialEvents, initialState, onLoadInternalEvent } from "utils/context.js"
 
 
 // Endpoint URLs.
 // Endpoint URLs.
 const EVENTURL = env.EVENT
 const EVENTURL = env.EVENT
@@ -529,10 +529,15 @@ export const useEventLoop = (
   }
   }
 
 
   const sentHydrate = useRef(false);  // Avoid double-hydrate due to React strict-mode
   const sentHydrate = useRef(false);  // Avoid double-hydrate due to React strict-mode
-  // initial state hydrate
   useEffect(() => {
   useEffect(() => {
     if (router.isReady && !sentHydrate.current) {
     if (router.isReady && !sentHydrate.current) {
-      addEvents(initial_events())
+    const events = initial_events()
+      addEvents(events.map((e) => (
+        {
+          ...e,
+          router_data: (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(router)
+        }
+      )))
       sentHydrate.current = true
       sentHydrate.current = true
     }
     }
   }, [router.isReady])
   }, [router.isReady])
@@ -560,7 +565,7 @@ export const useEventLoop = (
 
 
   // Route after the initial page hydration.
   // Route after the initial page hydration.
   useEffect(() => {
   useEffect(() => {
-    const change_complete = () => addEvents(initial_events())
+    const change_complete = () => addEvents(onLoadInternalEvent())
     router.events.on('routeChangeComplete', change_complete)
     router.events.on('routeChangeComplete', change_complete)
     return () => {
     return () => {
       router.events.off('routeChangeComplete', change_complete)
       router.events.off('routeChangeComplete', change_complete)

+ 1 - 1
reflex/app.pyi

@@ -125,7 +125,7 @@ class App(Base):
         self, state: State, event: Event
         self, state: State, event: Event
     ) -> asyncio.Task | None: ...
     ) -> asyncio.Task | None: ...
 
 
-async def process(
+def process(
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
     app: App, event: Event, sid: str, headers: Dict, client_ip: str
 ) -> AsyncIterator[StateUpdate]: ...
 ) -> AsyncIterator[StateUpdate]: ...
 async def ping() -> str: ...
 async def ping() -> str: ...

+ 1 - 0
reflex/compiler/templates.py

@@ -40,6 +40,7 @@ class ReflexJinjaEnvironment(Environment):
             "toggle_color_mode": constants.ColorMode.TOGGLE,
             "toggle_color_mode": constants.ColorMode.TOGGLE,
             "use_color_mode": constants.ColorMode.USE,
             "use_color_mode": constants.ColorMode.USE,
             "hydrate": constants.CompileVars.HYDRATE,
             "hydrate": constants.CompileVars.HYDRATE,
+            "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
         }
         }
 
 
 
 

+ 2 - 0
reflex/constants/__init__.py

@@ -48,6 +48,7 @@ from .route import (
     ROUTE_NOT_FOUND,
     ROUTE_NOT_FOUND,
     ROUTER,
     ROUTER,
     ROUTER_DATA,
     ROUTER_DATA,
+    ROUTER_DATA_INCLUDE,
     DefaultPage,
     DefaultPage,
     Page404,
     Page404,
     RouteArgType,
     RouteArgType,
@@ -97,6 +98,7 @@ __ALL__ = [
     RouteVar,
     RouteVar,
     ROUTER,
     ROUTER,
     ROUTER_DATA,
     ROUTER_DATA,
+    ROUTER_DATA_INCLUDE,
     ROUTE_NOT_FOUND,
     ROUTE_NOT_FOUND,
     SETTER_PREFIX,
     SETTER_PREFIX,
     SKIP_COMPILE_ENV_VAR,
     SKIP_COMPILE_ENV_VAR,

+ 2 - 0
reflex/constants/compiler.py

@@ -58,6 +58,8 @@ class CompileVars(SimpleNamespace):
     CONNECT_ERROR = "connectError"
     CONNECT_ERROR = "connectError"
     # The name of the function for converting a dict to an event.
     # The name of the function for converting a dict to an event.
     TO_EVENT = "Event"
     TO_EVENT = "Event"
+    # The name of the internal on_load event.
+    ON_LOAD_INTERNAL = "on_load_internal"
 
 
 
 
 class PageNames(SimpleNamespace):
 class PageNames(SimpleNamespace):

+ 4 - 0
reflex/constants/route.py

@@ -30,6 +30,10 @@ class RouteVar(SimpleNamespace):
     COOKIE = "cookie"
     COOKIE = "cookie"
 
 
 
 
+# This subset of router_data is included in chained on_load events.
+ROUTER_DATA_INCLUDE = set((RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY))
+
+
 class RouteRegex(SimpleNamespace):
 class RouteRegex(SimpleNamespace):
     """Regex used for extracting route args in route."""
     """Regex used for extracting route args in route."""
 
 

+ 11 - 1
reflex/event.py

@@ -826,6 +826,10 @@ def fix_events(
     # Fix the events created by the handler.
     # Fix the events created by the handler.
     out = []
     out = []
     for e in events:
     for e in events:
+        if isinstance(e, Event):
+            # If the event is already an event, append it to the list.
+            out.append(e)
+            continue
         if not isinstance(e, (EventHandler, EventSpec)):
         if not isinstance(e, (EventHandler, EventSpec)):
             e = EventHandler(fn=e)
             e = EventHandler(fn=e)
         # Otherwise, create an event from the event spec.
         # Otherwise, create an event from the event spec.
@@ -835,13 +839,19 @@ def fix_events(
         name = format.format_event_handler(e.handler)
         name = format.format_event_handler(e.handler)
         payload = {k._var_name: v._decode() for k, v in e.args}  # type: ignore
         payload = {k._var_name: v._decode() for k, v in e.args}  # type: ignore
 
 
+        # Filter router_data to reduce payload size
+        event_router_data = {
+            k: v
+            for k, v in (router_data or {}).items()
+            if k in constants.route.ROUTER_DATA_INCLUDE
+        }
         # Create an event and append it to the list.
         # Create an event and append it to the list.
         out.append(
         out.append(
             Event(
             Event(
                 token=token,
                 token=token,
                 name=name,
                 name=name,
                 payload=payload,
                 payload=payload,
-                router_data=router_data or {},
+                router_data=event_router_data,
             )
             )
         )
         )
 
 

+ 2 - 8
reflex/middleware/hydrate_middleware.py

@@ -4,7 +4,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Optional
 from typing import TYPE_CHECKING, Optional
 
 
 from reflex import constants
 from reflex import constants
-from reflex.event import Event, fix_events, get_hydrate_event
+from reflex.event import Event, get_hydrate_event
 from reflex.middleware.middleware import Middleware
 from reflex.middleware.middleware import Middleware
 from reflex.state import BaseState, StateUpdate
 from reflex.state import BaseState, StateUpdate
 from reflex.utils import format
 from reflex.utils import format
@@ -52,11 +52,5 @@ class HydrateMiddleware(Middleware):
         # since a full dict was captured, clean any dirtiness
         # since a full dict was captured, clean any dirtiness
         state._clean()
         state._clean()
 
 
-        # Get the route for on_load events.
-        route = event.router_data.get(constants.RouteVar.PATH, "")
-        # Add the on_load events and set is_hydrated to True.
-        events = [*app.get_load_events(route), type(state).set_is_hydrated(True)]  # type: ignore
-        events = fix_events(events, event.token, router_data=event.router_data)
-
         # Return the state update.
         # Return the state update.
-        return StateUpdate(delta=delta, events=events)
+        return StateUpdate(delta=delta, events=[])

+ 21 - 1
reflex/state.py

@@ -1016,7 +1016,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         """
         """
 
 
         def _is_valid_type(events: Any) -> bool:
         def _is_valid_type(events: Any) -> bool:
-            return isinstance(events, (EventHandler, EventSpec))
+            return isinstance(events, (Event, EventHandler, EventSpec))
 
 
         if events is None or _is_valid_type(events):
         if events is None or _is_valid_type(events):
             return events
             return events
@@ -1313,6 +1313,26 @@ class State(BaseState):
     # The hydrated bool.
     # The hydrated bool.
     is_hydrated: bool = False
     is_hydrated: bool = False
 
 
+    def on_load_internal(self) -> list[Event | EventSpec] | None:
+        """Queue on_load handlers for the current page.
+
+        Returns:
+            The list of events to queue for on load handling.
+        """
+        app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+        load_events = app.get_load_events(self.router.page.path)
+        if not load_events and self.is_hydrated:
+            return  # Fast path for page-to-page navigation
+        self.is_hydrated = False
+        return [
+            *fix_events(
+                load_events,
+                self.router.session.client_token,
+                router_data=self.router_data,
+            ),
+            type(self).set_is_hydrated(True),  # type: ignore
+        ]
+
 
 
 class StateProxy(wrapt.ObjectProxy):
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.
     """Proxy of a state instance to control mutability of vars for a background task.

+ 8 - 0
reflex/utils/prerequisites.py

@@ -123,9 +123,17 @@ def get_app(reload: bool = False) -> ModuleType:
 
 
     Returns:
     Returns:
         The app based on the default config.
         The app based on the default config.
+
+    Raises:
+        RuntimeError: If the app name is not set in the config.
     """
     """
     os.environ[constants.RELOAD_CONFIG] = str(reload)
     os.environ[constants.RELOAD_CONFIG] = str(reload)
     config = get_config()
     config = get_config()
+    if not config.app_name:
+        raise RuntimeError(
+            "Cannot get the app module because `app_name` is not set in rxconfig! "
+            "If this error occurs in a reflex test case, ensure that `get_app` is mocked."
+        )
     module = ".".join([config.app_name, config.app_name])
     module = ".".join([config.app_name, config.app_name])
     sys.path.insert(0, os.getcwd())
     sys.path.insert(0, os.getcwd())
     app = __import__(module, fromlist=(constants.CompileVars.APP,))
     app = __import__(module, fromlist=(constants.CompileVars.APP,))

+ 22 - 0
tests/conftest.py

@@ -5,11 +5,13 @@ import platform
 import uuid
 import uuid
 from pathlib import Path
 from pathlib import Path
 from typing import Dict, Generator
 from typing import Dict, Generator
+from unittest import mock
 
 
 import pytest
 import pytest
 
 
 from reflex.app import App
 from reflex.app import App
 from reflex.event import EventSpec
 from reflex.event import EventSpec
+from reflex.utils import prerequisites
 
 
 from .states import (
 from .states import (
     DictMutationTestState,
     DictMutationTestState,
@@ -30,6 +32,26 @@ def app() -> App:
     return App()
     return App()
 
 
 
 
+@pytest.fixture
+def app_module_mock(monkeypatch) -> mock.Mock:
+    """Mock the app module.
+
+    This overwrites prerequisites.get_app to return the mock for the app module.
+
+    To use this in your test, assign `app_module_mock.app = rx.App(...)`.
+
+    Args:
+        monkeypatch: pytest monkeypatch fixture.
+
+    Returns:
+        The mock for the main app module.
+    """
+    app_module_mock = mock.Mock()
+    get_app_mock = mock.Mock(return_value=app_module_mock)
+    monkeypatch.setattr(prerequisites, "get_app", get_app_mock)
+    return app_module_mock
+
+
 @pytest.fixture(scope="session")
 @pytest.fixture(scope="session")
 def windows_platform() -> Generator:
 def windows_platform() -> Generator:
     """Check if system is windows.
     """Check if system is windows.

+ 1 - 11
tests/middleware/conftest.py

@@ -21,14 +21,4 @@ def create_event(name):
 
 
 @pytest.fixture
 @pytest.fixture
 def event1():
 def event1():
-    return create_event("test_state.hydrate")
-
-
-@pytest.fixture
-def event2():
-    return create_event("test_state2.hydrate")
-
-
-@pytest.fixture
-def event3():
-    return create_event("test_state3.hydrate")
+    return create_event("state.hydrate")

+ 9 - 134
tests/middleware/test_hydrate_middleware.py

@@ -1,27 +1,13 @@
-from typing import Any, Dict
+from __future__ import annotations
 
 
 import pytest
 import pytest
 
 
-from reflex import constants
 from reflex.app import App
 from reflex.app import App
-from reflex.constants import CompileVars
 from reflex.middleware.hydrate_middleware import HydrateMiddleware
 from reflex.middleware.hydrate_middleware import HydrateMiddleware
-from reflex.state import BaseState, StateUpdate
+from reflex.state import State, StateUpdate
 
 
 
 
-def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
-    """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
-
-    Args:
-        state: the State that is hydrated
-
-    Returns:
-        dict similar to that returned by `State.get_delta` with IS_HYDRATED: True
-    """
-    return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
-
-
-class TestState(BaseState):
+class TestState(State):
     """A test state with no return in handler."""
     """A test state with no return in handler."""
 
 
     __test__ = False
     __test__ = False
@@ -33,40 +19,6 @@ class TestState(BaseState):
         self.num += 1
         self.num += 1
 
 
 
 
-class TestState2(BaseState):
-    """A test state with return in handler."""
-
-    __test__ = False
-
-    num: int = 0
-    name: str
-
-    def test_handler(self):
-        """Test handler that calls another handler.
-
-        Returns:
-            Chain of EventHandlers
-        """
-        self.num += 1
-        return self.change_name
-
-    def change_name(self):
-        """Test handler to change name."""
-        self.name = "random"
-
-
-class TestState3(BaseState):
-    """A test state with async handler."""
-
-    __test__ = False
-
-    num: int = 0
-
-    async def test_handler(self):
-        """Test handler."""
-        self.num += 1
-
-
 @pytest.fixture
 @pytest.fixture
 def hydrate_middleware() -> HydrateMiddleware:
 def hydrate_middleware() -> HydrateMiddleware:
     """Fixture creates an instance of HydrateMiddleware per test case.
     """Fixture creates an instance of HydrateMiddleware per test case.
@@ -78,98 +30,21 @@ def hydrate_middleware() -> HydrateMiddleware:
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "test_state, expected, event_fixture",
-    [
-        (TestState, {"test_state": {"num": 1}}, "event1"),
-        (TestState2, {"test_state2": {"num": 1}}, "event2"),
-        (TestState3, {"test_state3": {"num": 1}}, "event3"),
-    ],
-)
-async def test_preprocess(
-    test_state, hydrate_middleware, request, event_fixture, expected
-):
-    """Test that a state hydrate event is processed correctly.
-
-    Args:
-        test_state: State to process event.
-        hydrate_middleware: Instance of HydrateMiddleware.
-        request: Pytest fixture request.
-        event_fixture: The event fixture(an Event).
-        expected: Expected delta.
-    """
-    test_state.add_var(
-        constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
-    )
-    app = App(state=test_state, load_events={"index": [test_state.test_handler]})
-    state = test_state()
-
-    update = await hydrate_middleware.preprocess(
-        app=app, event=request.getfixturevalue(event_fixture), state=state
-    )
-    assert isinstance(update, StateUpdate)
-    assert update.delta == state.dict()
-    events = update.events
-    assert len(events) == 2
-
-    # Apply the on_load event.
-    update = await state._process(events[0]).__anext__()
-    assert update.delta == expected
-
-    # Apply the hydrate event.
-    update = await state._process(events[1]).__anext__()
-    assert update.delta == exp_is_hydrated(state)
-
-
-@pytest.mark.asyncio
-async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
-    """Test that a state hydrate event for multiple on-load events is processed correctly.
-
-    Args:
-        hydrate_middleware: Instance of HydrateMiddleware
-        event1: An Event.
-    """
-    app = App(
-        state=TestState,
-        load_events={"index": [TestState.test_handler, TestState.test_handler]},
-    )
-    state = TestState()
-
-    update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
-    assert isinstance(update, StateUpdate)
-    assert update.delta == state.dict()
-    assert len(update.events) == 3
-
-    # Apply the events.
-    events = update.events
-    update = await state._process(events[0]).__anext__()
-    assert update.delta == {"test_state": {"num": 1}}
-
-    update = await state._process(events[1]).__anext__()
-    assert update.delta == {"test_state": {"num": 2}}
-
-    update = await state._process(events[2]).__anext__()
-    assert update.delta == exp_is_hydrated(state)
-
-
-@pytest.mark.asyncio
-async def test_preprocess_no_events(hydrate_middleware, event1):
+async def test_preprocess_no_events(hydrate_middleware, event1, mocker):
     """Test that app without on_load is processed correctly.
     """Test that app without on_load is processed correctly.
 
 
     Args:
     Args:
         hydrate_middleware: Instance of HydrateMiddleware
         hydrate_middleware: Instance of HydrateMiddleware
         event1: An Event.
         event1: An Event.
+        mocker: pytest mock object.
     """
     """
-    state = TestState()
+    mocker.patch("reflex.state.State.class_subclasses", {TestState})
+    state = State()
     update = await hydrate_middleware.preprocess(
     update = await hydrate_middleware.preprocess(
-        app=App(state=TestState),
+        app=App(state=State),
         event=event1,
         event=event1,
         state=state,
         state=state,
     )
     )
     assert isinstance(update, StateUpdate)
     assert isinstance(update, StateUpdate)
     assert update.delta == state.dict()
     assert update.delta == state.dict()
-    assert len(update.events) == 1
-    assert isinstance(update, StateUpdate)
-
-    update = await state._process(update.events[0]).__anext__()
-    assert update.delta == exp_is_hydrated(state)
+    assert not update.events

+ 28 - 33
tests/test_app.py

@@ -25,10 +25,10 @@ from reflex.app import (
     upload,
     upload,
 )
 )
 from reflex.components import Box, Component, Cond, Fragment, Text
 from reflex.components import Box, Component, Cond, Fragment, Text
-from reflex.event import Event, get_hydrate_event
+from reflex.event import Event
 from reflex.middleware import HydrateMiddleware
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.model import Model
-from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
+from reflex.state import BaseState, State, StateManagerRedis, StateUpdate
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format
 from reflex.utils import format
 from reflex.vars import ComputedVar
 from reflex.vars import ComputedVar
@@ -870,6 +870,7 @@ class DynamicState(BaseState):
         recalculated when the dynamic route var was dirty
         recalculated when the dynamic route var was dirty
     """
     """
 
 
+    is_hydrated: bool = False
     loaded: int = 0
     loaded: int = 0
     counter: int = 0
     counter: int = 0
 
 
@@ -893,10 +894,16 @@ class DynamicState(BaseState):
         # self.side_effect_counter = self.side_effect_counter + 1
         # self.side_effect_counter = self.side_effect_counter + 1
         return self.dynamic
         return self.dynamic
 
 
+    on_load_internal = State.on_load_internal.fn
+
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_dynamic_route_var_route_change_completed_on_load(
 async def test_dynamic_route_var_route_change_completed_on_load(
-    index_page, windows_platform: bool, token: str, mocker
+    index_page,
+    windows_platform: bool,
+    token: str,
+    app_module_mock: unittest.mock.Mock,
+    mocker,
 ):
 ):
     """Create app with dynamic route var, and simulate navigation.
     """Create app with dynamic route var, and simulate navigation.
 
 
@@ -907,17 +914,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         index_page: The index page.
         index_page: The index page.
         windows_platform: Whether the system is windows.
         windows_platform: Whether the system is windows.
         token: a Token.
         token: a Token.
+        app_module_mock: Mocked app module.
         mocker: pytest mocker object.
         mocker: pytest mocker object.
     """
     """
-    mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
-    DynamicState.add_var(
-        constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
-    )
     arg_name = "dynamic"
     arg_name = "dynamic"
     route = f"/test/[{arg_name}]"
     route = f"/test/[{arg_name}]"
     if windows_platform:
     if windows_platform:
         route.lstrip("/").replace("/", "\\")
         route.lstrip("/").replace("/", "\\")
-    app = App(state=DynamicState)
+    app = app_module_mock.app = App(state=DynamicState)
     assert arg_name not in app.state.vars
     assert arg_name not in app.state.vars
     app.add_page(index_page, route=route, on_load=DynamicState.on_load)  # type: ignore
     app.add_page(index_page, route=route, on_load=DynamicState.on_load)  # type: ignore
     assert arg_name in app.state.vars
     assert arg_name in app.state.vars
@@ -953,33 +957,25 @@ async def test_dynamic_route_var_route_change_completed_on_load(
 
 
     prev_exp_val = ""
     prev_exp_val = ""
     for exp_index, exp_val in enumerate(exp_vals):
     for exp_index, exp_val in enumerate(exp_vals):
-        hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
-        exp_router_data = {
-            "headers": {},
-            "ip": client_ip,
-            "sid": sid,
-            "token": token,
-            **hydrate_event.router_data,
-        }
-        exp_router = RouterData(exp_router_data)
+        on_load_internal = _event(
+            name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}",
+            val=exp_val,
+        )
         process_coro = process(
         process_coro = process(
             app,
             app,
-            event=hydrate_event,
+            event=on_load_internal,
             sid=sid,
             sid=sid,
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
         )
         )
-        update = await process_coro.__anext__()  # type: ignore
-        # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
+        update = await process_coro.__anext__()
+        # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
         assert update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
                     arg_name: exp_val,
                     arg_name: exp_val,
                     f"comp_{arg_name}": exp_val,
                     f"comp_{arg_name}": exp_val,
                     constants.CompileVars.IS_HYDRATED: False,
                     constants.CompileVars.IS_HYDRATED: False,
-                    "loaded": exp_index,
-                    "counter": exp_index,
-                    "router": exp_router,
                     # "side_effect_counter": exp_index,
                     # "side_effect_counter": exp_index,
                 }
                 }
             },
             },
@@ -987,13 +983,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                 _dynamic_state_event(
                 _dynamic_state_event(
                     name="on_load",
                     name="on_load",
                     val=exp_val,
                     val=exp_val,
-                    router_data=exp_router_data,
                 ),
                 ),
                 _dynamic_state_event(
                 _dynamic_state_event(
                     name="set_is_hydrated",
                     name="set_is_hydrated",
                     payload={"value": True},
                     payload={"value": True},
                     val=exp_val,
                     val=exp_val,
-                    router_data=exp_router_data,
+                    router_data={},
                 ),
                 ),
             ],
             ],
         )
         )
@@ -1004,7 +999,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
 
 
         # complete the processing
         # complete the processing
         with pytest.raises(StopAsyncIteration):
         with pytest.raises(StopAsyncIteration):
-            await process_coro.__anext__()  # type: ignore
+            await process_coro.__anext__()
 
 
         # check that router data was written to the state_manager store
         # check that router data was written to the state_manager store
         state = await app.state_manager.get_state(token)
         state = await app.state_manager.get_state(token)
@@ -1017,7 +1012,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
         )
         )
-        on_load_update = await process_coro.__anext__()  # type: ignore
+        on_load_update = await process_coro.__anext__()
         assert on_load_update == StateUpdate(
         assert on_load_update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
@@ -1031,7 +1026,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         )
         # complete the processing
         # complete the processing
         with pytest.raises(StopAsyncIteration):
         with pytest.raises(StopAsyncIteration):
-            await process_coro.__anext__()  # type: ignore
+            await process_coro.__anext__()
         process_coro = process(
         process_coro = process(
             app,
             app,
             event=_dynamic_state_event(
             event=_dynamic_state_event(
@@ -1041,7 +1036,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
         )
         )
-        on_set_is_hydrated_update = await process_coro.__anext__()  # type: ignore
+        on_set_is_hydrated_update = await process_coro.__anext__()
         assert on_set_is_hydrated_update == StateUpdate(
         assert on_set_is_hydrated_update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
@@ -1055,7 +1050,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         )
         # complete the processing
         # complete the processing
         with pytest.raises(StopAsyncIteration):
         with pytest.raises(StopAsyncIteration):
-            await process_coro.__anext__()  # type: ignore
+            await process_coro.__anext__()
 
 
         # a simple state update event should NOT trigger on_load or route var side effects
         # a simple state update event should NOT trigger on_load or route var side effects
         process_coro = process(
         process_coro = process(
@@ -1065,7 +1060,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             headers={},
             headers={},
             client_ip=client_ip,
             client_ip=client_ip,
         )
         )
-        update = await process_coro.__anext__()  # type: ignore
+        update = await process_coro.__anext__()
         assert update == StateUpdate(
         assert update == StateUpdate(
             delta={
             delta={
                 state.get_name(): {
                 state.get_name(): {
@@ -1079,7 +1074,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         )
         )
         # complete the processing
         # complete the processing
         with pytest.raises(StopAsyncIteration):
         with pytest.raises(StopAsyncIteration):
-            await process_coro.__anext__()  # type: ignore
+            await process_coro.__anext__()
 
 
         prev_exp_val = exp_val
         prev_exp_val = exp_val
     state = await app.state_manager.get_state(token)
     state = await app.state_manager.get_state(token)
@@ -1116,7 +1111,7 @@ async def test_process_events(mocker, token: str):
         token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
         token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
     )
     )
 
 
-    async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):  # type: ignore
+    async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
         pass
         pass
 
 
     assert (await app.state_manager.get_state(token)).value == 5
     assert (await app.state_manager.get_state(token)).value == 5

+ 165 - 11
tests/test_state.py

@@ -7,7 +7,7 @@ import functools
 import json
 import json
 import os
 import os
 import sys
 import sys
-from typing import Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Union
 from unittest.mock import AsyncMock, Mock
 from unittest.mock import AsyncMock, Mock
 
 
 import pytest
 import pytest
@@ -24,6 +24,7 @@ from reflex.state import (
     LockExpiredError,
     LockExpiredError,
     MutableProxy,
     MutableProxy,
     RouterData,
     RouterData,
+    State,
     StateManager,
     StateManager,
     StateManagerMemory,
     StateManagerMemory,
     StateManagerRedis,
     StateManagerRedis,
@@ -1374,8 +1375,13 @@ def test_error_on_state_method_shadow():
     )
     )
 
 
 
 
-def test_state_with_invalid_yield():
-    """Test that an error is thrown when a state yields an invalid value."""
+@pytest.mark.asyncio
+async def test_state_with_invalid_yield(capsys):
+    """Test that an error is thrown when a state yields an invalid value.
+
+    Args:
+        capsys: Pytest fixture for capture standard streams.
+    """
 
 
     class StateWithInvalidYield(BaseState):
     class StateWithInvalidYield(BaseState):
         """A state that yields an invalid value."""
         """A state that yields an invalid value."""
@@ -1389,15 +1395,16 @@ def test_state_with_invalid_yield():
             yield 1
             yield 1
 
 
     invalid_state = StateWithInvalidYield()
     invalid_state = StateWithInvalidYield()
-    with pytest.raises(TypeError) as err:
-        invalid_state._check_valid(
-            invalid_state.event_handlers["invalid_handler"],
-            rx.event.Event(token="fake_token", name="invalid_handler"),
+    async for update in invalid_state._process(
+        rx.event.Event(token="fake_token", name="invalid_handler")
+    ):
+        assert not update.delta
+        assert update.events == rx.event.fix_events(
+            [rx.window_alert("An error occurred. See logs for details.")],
+            token="",
         )
         )
-    assert (
-        "must only return/yield: None, Events or other EventHandlers"
-        in err.value.args[0]
-    )
+    captured = capsys.readouterr()
+    assert "must only return/yield: None, Events or other EventHandlers" in captured.out
 
 
 
 
 @pytest.fixture(scope="function", params=["in_process", "redis"])
 @pytest.fixture(scope="function", params=["in_process", "redis"])
@@ -2303,3 +2310,150 @@ def test_state_union_optional():
     assert UnionState.custom_union.c2r is not None  # type: ignore
     assert UnionState.custom_union.c2r is not None  # type: ignore
     assert types.is_optional(UnionState.opt_int._var_type)  # type: ignore
     assert types.is_optional(UnionState.opt_int._var_type)  # type: ignore
     assert types.is_union(UnionState.int_float._var_type)  # type: ignore
     assert types.is_union(UnionState.int_float._var_type)  # type: ignore
+
+
+def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
+    """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
+
+    Args:
+        state: the State that is hydrated.
+        is_hydrated: whether the state is hydrated.
+
+    Returns:
+        dict similar to that returned by `State.get_delta` with IS_HYDRATED: is_hydrated
+    """
+    return {state.get_full_name(): {CompileVars.IS_HYDRATED: is_hydrated}}
+
+
+class OnLoadState(State):
+    """A test state with no return in handler."""
+
+    num: int = 0
+
+    def test_handler(self):
+        """Test handler."""
+        self.num += 1
+
+
+class OnLoadState2(State):
+    """A test state with return in handler."""
+
+    num: int = 0
+    name: str
+
+    def test_handler(self):
+        """Test handler that calls another handler.
+
+        Returns:
+            Chain of EventHandlers
+        """
+        self.num += 1
+        return self.change_name
+
+    def change_name(self):
+        """Test handler to change name."""
+        self.name = "random"
+
+
+class OnLoadState3(State):
+    """A test state with async handler."""
+
+    num: int = 0
+
+    async def test_handler(self):
+        """Test handler."""
+        self.num += 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "test_state, expected",
+    [
+        (OnLoadState, {"on_load_state": {"num": 1}}),
+        (OnLoadState2, {"on_load_state2": {"num": 1}}),
+        (OnLoadState3, {"on_load_state3": {"num": 1}}),
+    ],
+)
+async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
+    """Test that a state hydrate event is processed correctly.
+
+    Args:
+        app_module_mock: The app module that will be returned by get_app().
+        token: A token.
+        test_state: State to process event.
+        expected: Expected delta.
+        mocker: pytest mock object.
+    """
+    mocker.patch("reflex.state.State.class_subclasses", {test_state})
+    app = app_module_mock.app = App(
+        state=State, load_events={"index": [test_state.test_handler]}
+    )
+    state = State()
+
+    updates = []
+    async for update in rx.app.process(
+        app=app,
+        event=Event(
+            token=token,
+            name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}",
+            router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
+        ),
+        sid="sid",
+        headers={},
+        client_ip="",
+    ):
+        assert isinstance(update, StateUpdate)
+        updates.append(update)
+    assert len(updates) == 1
+    assert updates[0].delta == exp_is_hydrated(state, False)
+
+    events = updates[0].events
+    assert len(events) == 2
+    assert (await state._process(events[0]).__anext__()).delta == {
+        test_state.get_full_name(): {"num": 1}
+    }
+    assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
+
+
+@pytest.mark.asyncio
+async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
+    """Test that a state hydrate event for multiple on-load events is processed correctly.
+
+    Args:
+        app_module_mock: The app module that will be returned by get_app().
+        token: A token.
+        mocker: pytest mock object.
+    """
+    mocker.patch("reflex.state.State.class_subclasses", {OnLoadState})
+    app = app_module_mock.app = App(
+        state=State,
+        load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
+    )
+    state = State()
+
+    updates = []
+    async for update in rx.app.process(
+        app=app,
+        event=Event(
+            token=token,
+            name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}",
+            router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
+        ),
+        sid="sid",
+        headers={},
+        client_ip="",
+    ):
+        assert isinstance(update, StateUpdate)
+        updates.append(update)
+    assert len(updates) == 1
+    assert updates[0].delta == exp_is_hydrated(state, False)
+
+    events = updates[0].events
+    assert len(events) == 3
+    assert (await state._process(events[0]).__anext__()).delta == {
+        OnLoadState.get_full_name(): {"num": 1}
+    }
+    assert (await state._process(events[1]).__anext__()).delta == {
+        OnLoadState.get_full_name(): {"num": 2}
+    }
+    assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)