hydrate_middleware.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. """Middleware to hydrate the state."""
  2. from __future__ import annotations
  3. import dataclasses
  4. from typing import TYPE_CHECKING, ChainMap
  5. from reflex import constants
  6. from reflex.event import Event, get_hydrate_event, get_partial_hydrate_event
  7. from reflex.middleware.middleware import Middleware
  8. from reflex.state import BaseState, StateDelta, StateUpdate, _resolve_delta
  9. if TYPE_CHECKING:
  10. from reflex.app import App
  11. @dataclasses.dataclass(init=True)
  12. class HydrateMiddleware(Middleware):
  13. """Middleware to handle initial app hydration."""
  14. async def preprocess(
  15. self, app: App, state: BaseState, event: Event
  16. ) -> StateUpdate | None:
  17. """Preprocess the event.
  18. Args:
  19. app: The app to apply the middleware to.
  20. state: The client state.
  21. event: The event to preprocess.
  22. Returns:
  23. An optional delta or list of state updates to return.
  24. """
  25. # If this is not the hydrate event, return None
  26. if event.name != get_hydrate_event(state):
  27. return None
  28. # Clear client storage, to respect clearing cookies
  29. state._reset_client_storage()
  30. # Mark state as not hydrated (until on_loads are complete)
  31. setattr(state, constants.CompileVars.IS_HYDRATED, False)
  32. # Get the initial state.
  33. delta = await _resolve_delta(
  34. StateDelta(
  35. state.dict(),
  36. client_token=state.router.session.client_token,
  37. flush=True,
  38. )
  39. )
  40. # since a full dict was captured, clean any dirtiness
  41. state._clean()
  42. # Return the state update.
  43. return StateUpdate(delta=delta, events=[])
  44. @dataclasses.dataclass(init=True)
  45. class PartialHyderateMiddleware(Middleware):
  46. """Middleware to handle partial app hydration."""
  47. async def preprocess(
  48. self, app: App, state: BaseState, event: Event
  49. ) -> StateUpdate | None:
  50. """Preprocess the event.
  51. Args:
  52. app: The app to apply the middleware to."
  53. state: The client state.""
  54. event: The event to preprocess.""
  55. Returns:
  56. An optional delta or list of state updates to return.""
  57. """
  58. # If this is not the partial hydrate event, return None
  59. if event.name != get_partial_hydrate_event(state):
  60. return None
  61. substates_names = event.payload.get("states", [])
  62. if not substates_names:
  63. return None
  64. substates = [
  65. substate
  66. for substate_name in substates_names
  67. if (substate := state.get_substate(substate_name.split("."))) is not None
  68. ]
  69. delta = await _resolve_delta(
  70. StateDelta(
  71. ChainMap(*[substate.dict() for substate in substates]),
  72. client_token=state.router.session.client_token,
  73. flush=True,
  74. )
  75. )
  76. # since a full dict was captured, clean any dirtiness
  77. for substate in substates:
  78. substate._clean()
  79. # Return the state update.
  80. return StateUpdate(delta=delta, events=[])