Browse Source

split lifespan and middleware logic in separate mixin files (#3557)

* split lifespan and middleware logic in separate mixin files

* fix for 3.8

* fix for unit tests

* add missing sys import

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Thomas Brandého 11 tháng trước cách đây
mục cha
commit
f0bab665ce

+ 8 - 123
reflex/app.py

@@ -7,7 +7,6 @@ import concurrent.futures
 import contextlib
 import copy
 import functools
-import inspect
 import io
 import multiprocessing
 import os
@@ -40,6 +39,7 @@ from starlette_admin.contrib.sqla.view import ModelView
 
 from reflex import constants
 from reflex.admin import AdminDash
+from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
 from reflex.base import Base
 from reflex.compiler import compiler
 from reflex.compiler import utils as compiler_utils
@@ -61,7 +61,6 @@ from reflex.components.core.upload import Upload, get_upload_dir
 from reflex.components.radix import themes
 from reflex.config import get_config
 from reflex.event import Event, EventHandler, EventSpec
-from reflex.middleware import HydrateMiddleware, Middleware
 from reflex.model import Model
 from reflex.page import (
     DECORATED_PAGES,
@@ -108,50 +107,7 @@ class OverlayFragment(Fragment):
     pass
 
 
-class LifespanMixin(Base):
-    """A Mixin that allow tasks to run during the whole app lifespan."""
-
-    # Lifespan tasks that are planned to run.
-    lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
-
-    @contextlib.asynccontextmanager
-    async def _run_lifespan_tasks(self, app: FastAPI):
-        running_tasks = []
-        try:
-            async with contextlib.AsyncExitStack() as stack:
-                for task in self.lifespan_tasks:
-                    if isinstance(task, asyncio.Task):
-                        running_tasks.append(task)
-                    else:
-                        signature = inspect.signature(task)
-                        if "app" in signature.parameters:
-                            task = functools.partial(task, app=app)
-                        _t = task()
-                        if isinstance(_t, contextlib._AsyncGeneratorContextManager):
-                            await stack.enter_async_context(_t)
-                        elif isinstance(_t, Coroutine):
-                            running_tasks.append(asyncio.create_task(_t))
-                yield
-        finally:
-            cancel_kwargs = (
-                {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
-            )
-            for task in running_tasks:
-                task.cancel(**cancel_kwargs)
-
-    def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
-        """Register a task to run during the lifespan of the app.
-
-        Args:
-            task: The task to register.
-            task_kwargs: The kwargs of the task.
-        """
-        if task_kwargs:
-            task = functools.partial(task, **task_kwargs)  # type: ignore
-        self.lifespan_tasks.add(task)  # type: ignore
-
-
-class App(LifespanMixin, Base):
+class App(MiddlewareMixin, LifespanMixin, Base):
     """The main Reflex app that encapsulates the backend and frontend.
 
     Every Reflex app needs an app defined in its main module.
@@ -210,9 +166,6 @@ class App(LifespanMixin, Base):
     # Class to manage many client states.
     _state_manager: Optional[StateManager] = None
 
-    # Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
-    middleware: List[Middleware] = []
-
     # Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
     load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
 
@@ -253,14 +206,17 @@ class App(LifespanMixin, Base):
         if "breakpoints" in self.style:
             set_breakpoints(self.style.pop("breakpoints"))
 
-        # Add middleware.
-        self.middleware.append(HydrateMiddleware())
-
         # Set up the API.
         self.api = FastAPI(lifespan=self._run_lifespan_tasks)
         self._add_cors()
         self._add_default_endpoints()
 
+        for clz in App.__mro__:
+            if clz == App:
+                continue
+            if issubclass(clz, AppMixin):
+                clz._init_mixin(self)
+
         self._setup_state()
 
         # Set up the admin dash.
@@ -385,77 +341,6 @@ class App(LifespanMixin, Base):
             raise ValueError("The state manager has not been initialized.")
         return self._state_manager
 
-    async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
-        """Preprocess the event.
-
-        This is where middleware can modify the event before it is processed.
-        Each middleware is called in the order it was added to the app.
-
-        If a middleware returns an update, the event is not processed and the
-        update is returned.
-
-        Args:
-            state: The state to preprocess.
-            event: The event to preprocess.
-
-        Returns:
-            An optional state to return.
-        """
-        for middleware in self.middleware:
-            if asyncio.iscoroutinefunction(middleware.preprocess):
-                out = await middleware.preprocess(app=self, state=state, event=event)  # type: ignore
-            else:
-                out = middleware.preprocess(app=self, state=state, event=event)  # type: ignore
-            if out is not None:
-                return out  # type: ignore
-
-    async def _postprocess(
-        self, state: BaseState, event: Event, update: StateUpdate
-    ) -> StateUpdate:
-        """Postprocess the event.
-
-        This is where middleware can modify the delta after it is processed.
-        Each middleware is called in the order it was added to the app.
-
-        Args:
-            state: The state to postprocess.
-            event: The event to postprocess.
-            update: The current state update.
-
-        Returns:
-            The state update to return.
-        """
-        for middleware in self.middleware:
-            if asyncio.iscoroutinefunction(middleware.postprocess):
-                out = await middleware.postprocess(
-                    app=self,  # type: ignore
-                    state=state,
-                    event=event,
-                    update=update,
-                )
-            else:
-                out = middleware.postprocess(
-                    app=self,  # type: ignore
-                    state=state,
-                    event=event,
-                    update=update,
-                )
-            if out is not None:
-                return out  # type: ignore
-        return update
-
-    def add_middleware(self, middleware: Middleware, index: int | None = None):
-        """Add middleware to the app.
-
-        Args:
-            middleware: The middleware to add.
-            index: The index to add the middleware at.
-        """
-        if index is None:
-            self.middleware.append(middleware)
-        else:
-            self.middleware.insert(index, middleware)
-
     @staticmethod
     def _generate_component(component: Component | ComponentCallable) -> Component:
         """Generate a component from a callable.

+ 5 - 0
reflex/app_mixins/__init__.py

@@ -0,0 +1,5 @@
+"""Reflex AppMixins package."""
+
+from .lifespan import LifespanMixin
+from .middleware import MiddlewareMixin
+from .mixin import AppMixin

+ 57 - 0
reflex/app_mixins/lifespan.py

@@ -0,0 +1,57 @@
+"""Mixin that allow tasks to run during the whole app lifespan."""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import functools
+import inspect
+import sys
+from typing import Callable, Coroutine, Set, Union
+
+from fastapi import FastAPI
+
+from .mixin import AppMixin
+
+
+class LifespanMixin(AppMixin):
+    """A Mixin that allow tasks to run during the whole app lifespan."""
+
+    # Lifespan tasks that are planned to run.
+    lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
+
+    @contextlib.asynccontextmanager
+    async def _run_lifespan_tasks(self, app: FastAPI):
+        running_tasks = []
+        try:
+            async with contextlib.AsyncExitStack() as stack:
+                for task in self.lifespan_tasks:
+                    if isinstance(task, asyncio.Task):
+                        running_tasks.append(task)
+                    else:
+                        signature = inspect.signature(task)
+                        if "app" in signature.parameters:
+                            task = functools.partial(task, app=app)
+                        _t = task()
+                        if isinstance(_t, contextlib._AsyncGeneratorContextManager):
+                            await stack.enter_async_context(_t)
+                        elif isinstance(_t, Coroutine):
+                            running_tasks.append(asyncio.create_task(_t))
+                yield
+        finally:
+            cancel_kwargs = (
+                {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
+            )
+            for task in running_tasks:
+                task.cancel(**cancel_kwargs)
+
+    def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
+        """Register a task to run during the lifespan of the app.
+
+        Args:
+            task: The task to register.
+            task_kwargs: The kwargs of the task.
+        """
+        if task_kwargs:
+            task = functools.partial(task, **task_kwargs)  # type: ignore
+        self.lifespan_tasks.add(task)  # type: ignore

+ 93 - 0
reflex/app_mixins/middleware.py

@@ -0,0 +1,93 @@
+"""Middleware Mixin that allow to add middleware to the app."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import List
+
+from reflex.event import Event
+from reflex.middleware import HydrateMiddleware, Middleware
+from reflex.state import BaseState, StateUpdate
+
+from .mixin import AppMixin
+
+
+class MiddlewareMixin(AppMixin):
+    """Middleware Mixin that allow to add middleware to the app."""
+
+    # Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
+    middleware: List[Middleware] = []
+
+    def _init_mixin(self):
+        self.middleware.append(HydrateMiddleware())
+
+    def add_middleware(self, middleware: Middleware, index: int | None = None):
+        """Add middleware to the app.
+
+        Args:
+            middleware: The middleware to add.
+            index: The index to add the middleware at.
+        """
+        if index is None:
+            self.middleware.append(middleware)
+        else:
+            self.middleware.insert(index, middleware)
+
+    async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
+        """Preprocess the event.
+
+        This is where middleware can modify the event before it is processed.
+        Each middleware is called in the order it was added to the app.
+
+        If a middleware returns an update, the event is not processed and the
+        update is returned.
+
+        Args:
+            state: The state to preprocess.
+            event: The event to preprocess.
+
+        Returns:
+            An optional state to return.
+        """
+        for middleware in self.middleware:
+            if asyncio.iscoroutinefunction(middleware.preprocess):
+                out = await middleware.preprocess(app=self, state=state, event=event)  # type: ignore
+            else:
+                out = middleware.preprocess(app=self, state=state, event=event)  # type: ignore
+            if out is not None:
+                return out  # type: ignore
+
+    async def _postprocess(
+        self, state: BaseState, event: Event, update: StateUpdate
+    ) -> StateUpdate:
+        """Postprocess the event.
+
+        This is where middleware can modify the delta after it is processed.
+        Each middleware is called in the order it was added to the app.
+
+        Args:
+            state: The state to postprocess.
+            event: The event to postprocess.
+            update: The current state update.
+
+        Returns:
+            The state update to return.
+        """
+        for middleware in self.middleware:
+            if asyncio.iscoroutinefunction(middleware.postprocess):
+                out = await middleware.postprocess(
+                    app=self,  # type: ignore
+                    state=state,
+                    event=event,
+                    update=update,
+                )
+            else:
+                out = middleware.postprocess(
+                    app=self,  # type: ignore
+                    state=state,
+                    event=event,
+                    update=update,
+                )
+            if out is not None:
+                return out  # type: ignore
+        return update

+ 14 - 0
reflex/app_mixins/mixin.py

@@ -0,0 +1,14 @@
+"""Default mixin for all app mixins."""
+
+from reflex.base import Base
+
+
+class AppMixin(Base):
+    """Define the base class for all app mixins."""
+
+    def _init_mixin(self):
+        """Initialize the mixin.
+
+        Any App mixin can override this method to perform any initialization.
+        """
+        ...