|
@@ -3,10 +3,10 @@
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
import dataclasses
|
|
import dataclasses
|
|
-from typing import TYPE_CHECKING
|
|
|
|
|
|
+from typing import TYPE_CHECKING, ChainMap
|
|
|
|
|
|
from reflex import constants
|
|
from reflex import constants
|
|
-from reflex.event import Event, get_hydrate_event
|
|
|
|
|
|
+from reflex.event import Event, get_hydrate_event, get_partial_hydrate_event
|
|
from reflex.middleware.middleware import Middleware
|
|
from reflex.middleware.middleware import Middleware
|
|
from reflex.state import BaseState, StateDelta, StateUpdate, _resolve_delta
|
|
from reflex.state import BaseState, StateDelta, StateUpdate, _resolve_delta
|
|
|
|
|
|
@@ -54,3 +54,50 @@ class HydrateMiddleware(Middleware):
|
|
|
|
|
|
# Return the state update.
|
|
# Return the state update.
|
|
return StateUpdate(delta=delta, events=[])
|
|
return StateUpdate(delta=delta, events=[])
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@dataclasses.dataclass(init=True)
|
|
|
|
+class PartialHyderateMiddleware(Middleware):
|
|
|
|
+ """Middleware to handle partial app hydration."""
|
|
|
|
+
|
|
|
|
+ async def preprocess(
|
|
|
|
+ self, app: App, state: BaseState, event: Event
|
|
|
|
+ ) -> StateUpdate | None:
|
|
|
|
+ """Preprocess the event.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ app: The app to apply the middleware to."
|
|
|
|
+ state: The client state.""
|
|
|
|
+ event: The event to preprocess.""
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ An optional delta or list of state updates to return.""
|
|
|
|
+ """
|
|
|
|
+ # If this is not the partial hydrate event, return None
|
|
|
|
+ if event.name != get_partial_hydrate_event(state):
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ substates_names = event.payload.get("states", [])
|
|
|
|
+ if not substates_names:
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ substates = [
|
|
|
|
+ substate
|
|
|
|
+ for substate_name in substates_names
|
|
|
|
+ if (substate := state.get_substate(substate_name)) is not None
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ delta = await _resolve_delta(
|
|
|
|
+ StateDelta(
|
|
|
|
+ ChainMap(*[substate.dict() for substate in substates]),
|
|
|
|
+ client_token=state.router.session.client_token,
|
|
|
|
+ flush=True,
|
|
|
|
+ )
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # since a full dict was captured, clean any dirtiness
|
|
|
|
+ for substate in substates:
|
|
|
|
+ substate._clean()
|
|
|
|
+
|
|
|
|
+ # Return the state update.
|
|
|
|
+ return StateUpdate(delta=delta, events=[])
|