middleware.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """Middleware Mixin that allow to add middleware to the app."""
  2. from __future__ import annotations
  3. import asyncio
  4. import dataclasses
  5. from reflex.event import Event
  6. from reflex.middleware import HydrateMiddleware, Middleware
  7. from reflex.state import BaseState, StateUpdate
  8. from .mixin import AppMixin
  9. @dataclasses.dataclass
  10. class MiddlewareMixin(AppMixin):
  11. """Middleware Mixin that allow to add middleware to the app."""
  12. # Middleware to add to the app. Users should use `add_middleware`.
  13. _middlewares: list[Middleware] = dataclasses.field(default_factory=list)
  14. def _init_mixin(self):
  15. self._middlewares.append(HydrateMiddleware())
  16. def add_middleware(self, middleware: Middleware, index: int | None = None):
  17. """Add middleware to the app.
  18. Args:
  19. middleware: The middleware to add.
  20. index: The index to add the middleware at.
  21. """
  22. if index is None:
  23. self._middlewares.append(middleware)
  24. else:
  25. self._middlewares.insert(index, middleware)
  26. async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
  27. """Preprocess the event.
  28. This is where middleware can modify the event before it is processed.
  29. Each middleware is called in the order it was added to the app.
  30. If a middleware returns an update, the event is not processed and the
  31. update is returned.
  32. Args:
  33. state: The state to preprocess.
  34. event: The event to preprocess.
  35. Returns:
  36. An optional state to return.
  37. """
  38. for middleware in self._middlewares:
  39. if asyncio.iscoroutinefunction(middleware.preprocess):
  40. out = await middleware.preprocess(app=self, state=state, event=event) # pyright: ignore [reportArgumentType]
  41. else:
  42. out = middleware.preprocess(app=self, state=state, event=event) # pyright: ignore [reportArgumentType]
  43. if out is not None:
  44. return out # pyright: ignore [reportReturnType]
  45. async def _postprocess(
  46. self, state: BaseState, event: Event, update: StateUpdate
  47. ) -> StateUpdate:
  48. """Postprocess the event.
  49. This is where middleware can modify the delta after it is processed.
  50. Each middleware is called in the order it was added to the app.
  51. Args:
  52. state: The state to postprocess.
  53. event: The event to postprocess.
  54. update: The current state update.
  55. Returns:
  56. The state update to return.
  57. """
  58. out = update
  59. for middleware in self._middlewares:
  60. if asyncio.iscoroutinefunction(middleware.postprocess):
  61. out = await middleware.postprocess(
  62. app=self, # pyright: ignore [reportArgumentType]
  63. state=state,
  64. event=event,
  65. update=update,
  66. )
  67. else:
  68. out = middleware.postprocess(
  69. app=self, # pyright: ignore [reportArgumentType]
  70. state=state,
  71. event=event,
  72. update=update,
  73. )
  74. return out # pyright: ignore[reportReturnType]