Sfoglia il codice sorgente

split lifespan and middleware logic in separate mixin files (#3557)

* split lifespan and middleware logic in separate mixin files

* fix for 3.8

* fix for unit tests

* add missing sys import

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Thomas Brandého 11 mesi fa
parent
commit
f0bab665ce

+ 8 - 123
reflex/app.py

@@ -7,7 +7,6 @@ import concurrent.futures
 import contextlib
 import contextlib
 import copy
 import copy
 import functools
 import functools
-import inspect
 import io
 import io
 import multiprocessing
 import multiprocessing
 import os
 import os
@@ -40,6 +39,7 @@ from starlette_admin.contrib.sqla.view import ModelView
 
 
 from reflex import constants
 from reflex import constants
 from reflex.admin import AdminDash
 from reflex.admin import AdminDash
+from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
 from reflex.base import Base
 from reflex.base import Base
 from reflex.compiler import compiler
 from reflex.compiler import compiler
 from reflex.compiler import utils as compiler_utils
 from reflex.compiler import utils as compiler_utils
@@ -61,7 +61,6 @@ from reflex.components.core.upload import Upload, get_upload_dir
 from reflex.components.radix import themes
 from reflex.components.radix import themes
 from reflex.config import get_config
 from reflex.config import get_config
 from reflex.event import Event, EventHandler, EventSpec
 from reflex.event import Event, EventHandler, EventSpec
-from reflex.middleware import HydrateMiddleware, Middleware
 from reflex.model import Model
 from reflex.model import Model
 from reflex.page import (
 from reflex.page import (
     DECORATED_PAGES,
     DECORATED_PAGES,
@@ -108,50 +107,7 @@ class OverlayFragment(Fragment):
     pass
     pass
 
 
 
 
-class LifespanMixin(Base):
-    """A Mixin that allow tasks to run during the whole app lifespan."""
-
-    # Lifespan tasks that are planned to run.
-    lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
-
-    @contextlib.asynccontextmanager
-    async def _run_lifespan_tasks(self, app: FastAPI):
-        running_tasks = []
-        try:
-            async with contextlib.AsyncExitStack() as stack:
-                for task in self.lifespan_tasks:
-                    if isinstance(task, asyncio.Task):
-                        running_tasks.append(task)
-                    else:
-                        signature = inspect.signature(task)
-                        if "app" in signature.parameters:
-                            task = functools.partial(task, app=app)
-                        _t = task()
-                        if isinstance(_t, contextlib._AsyncGeneratorContextManager):
-                            await stack.enter_async_context(_t)
-                        elif isinstance(_t, Coroutine):
-                            running_tasks.append(asyncio.create_task(_t))
-                yield
-        finally:
-            cancel_kwargs = (
-                {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
-            )
-            for task in running_tasks:
-                task.cancel(**cancel_kwargs)
-
-    def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
-        """Register a task to run during the lifespan of the app.
-
-        Args:
-            task: The task to register.
-            task_kwargs: The kwargs of the task.
-        """
-        if task_kwargs:
-            task = functools.partial(task, **task_kwargs)  # type: ignore
-        self.lifespan_tasks.add(task)  # type: ignore
-
-
-class App(LifespanMixin, Base):
+class App(MiddlewareMixin, LifespanMixin, Base):
     """The main Reflex app that encapsulates the backend and frontend.
     """The main Reflex app that encapsulates the backend and frontend.
 
 
     Every Reflex app needs an app defined in its main module.
     Every Reflex app needs an app defined in its main module.
@@ -210,9 +166,6 @@ class App(LifespanMixin, Base):
     # Class to manage many client states.
     # Class to manage many client states.
     _state_manager: Optional[StateManager] = None
     _state_manager: Optional[StateManager] = None
 
 
-    # Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
-    middleware: List[Middleware] = []
-
     # Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
     # Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
     load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
     load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
 
 
@@ -253,14 +206,17 @@ class App(LifespanMixin, Base):
         if "breakpoints" in self.style:
         if "breakpoints" in self.style:
             set_breakpoints(self.style.pop("breakpoints"))
             set_breakpoints(self.style.pop("breakpoints"))
 
 
-        # Add middleware.
-        self.middleware.append(HydrateMiddleware())
-
         # Set up the API.
         # Set up the API.
         self.api = FastAPI(lifespan=self._run_lifespan_tasks)
         self.api = FastAPI(lifespan=self._run_lifespan_tasks)
         self._add_cors()
         self._add_cors()
         self._add_default_endpoints()
         self._add_default_endpoints()
 
 
+        for clz in App.__mro__:
+            if clz == App:
+                continue
+            if issubclass(clz, AppMixin):
+                clz._init_mixin(self)
+
         self._setup_state()
         self._setup_state()
 
 
         # Set up the admin dash.
         # Set up the admin dash.
@@ -385,77 +341,6 @@ class App(LifespanMixin, Base):
             raise ValueError("The state manager has not been initialized.")
             raise ValueError("The state manager has not been initialized.")
         return self._state_manager
         return self._state_manager
 
 
-    async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
-        """Preprocess the event.
-
-        This is where middleware can modify the event before it is processed.
-        Each middleware is called in the order it was added to the app.
-
-        If a middleware returns an update, the event is not processed and the
-        update is returned.
-
-        Args:
-            state: The state to preprocess.
-            event: The event to preprocess.
-
-        Returns:
-            An optional state to return.
-        """
-        for middleware in self.middleware:
-            if asyncio.iscoroutinefunction(middleware.preprocess):
-                out = await middleware.preprocess(app=self, state=state, event=event)  # type: ignore
-            else:
-                out = middleware.preprocess(app=self, state=state, event=event)  # type: ignore
-            if out is not None:
-                return out  # type: ignore
-
-    async def _postprocess(
-        self, state: BaseState, event: Event, update: StateUpdate
-    ) -> StateUpdate:
-        """Postprocess the event.
-
-        This is where middleware can modify the delta after it is processed.
-        Each middleware is called in the order it was added to the app.
-
-        Args:
-            state: The state to postprocess.
-            event: The event to postprocess.
-            update: The current state update.
-
-        Returns:
-            The state update to return.
-        """
-        for middleware in self.middleware:
-            if asyncio.iscoroutinefunction(middleware.postprocess):
-                out = await middleware.postprocess(
-                    app=self,  # type: ignore
-                    state=state,
-                    event=event,
-                    update=update,
-                )
-            else:
-                out = middleware.postprocess(
-                    app=self,  # type: ignore
-                    state=state,
-                    event=event,
-                    update=update,
-                )
-            if out is not None:
-                return out  # type: ignore
-        return update
-
-    def add_middleware(self, middleware: Middleware, index: int | None = None):
-        """Add middleware to the app.
-
-        Args:
-            middleware: The middleware to add.
-            index: The index to add the middleware at.
-        """
-        if index is None:
-            self.middleware.append(middleware)
-        else:
-            self.middleware.insert(index, middleware)
-
     @staticmethod
     @staticmethod
     def _generate_component(component: Component | ComponentCallable) -> Component:
     def _generate_component(component: Component | ComponentCallable) -> Component:
         """Generate a component from a callable.
         """Generate a component from a callable.

+ 5 - 0
reflex/app_mixins/__init__.py

@@ -0,0 +1,5 @@
+"""Reflex AppMixins package."""
+
+from .lifespan import LifespanMixin
+from .middleware import MiddlewareMixin
+from .mixin import AppMixin

+ 57 - 0
reflex/app_mixins/lifespan.py

@@ -0,0 +1,57 @@
+"""Mixin that allow tasks to run during the whole app lifespan."""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import functools
+import inspect
+import sys
+from typing import Callable, Coroutine, Set, Union
+
+from fastapi import FastAPI
+
+from .mixin import AppMixin
+
+
+class LifespanMixin(AppMixin):
+    """A Mixin that allow tasks to run during the whole app lifespan."""
+
+    # Lifespan tasks that are planned to run.
+    lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
+
+    @contextlib.asynccontextmanager
+    async def _run_lifespan_tasks(self, app: FastAPI):
+        running_tasks = []
+        try:
+            async with contextlib.AsyncExitStack() as stack:
+                for task in self.lifespan_tasks:
+                    if isinstance(task, asyncio.Task):
+                        running_tasks.append(task)
+                    else:
+                        signature = inspect.signature(task)
+                        if "app" in signature.parameters:
+                            task = functools.partial(task, app=app)
+                        _t = task()
+                        if isinstance(_t, contextlib._AsyncGeneratorContextManager):
+                            await stack.enter_async_context(_t)
+                        elif isinstance(_t, Coroutine):
+                            running_tasks.append(asyncio.create_task(_t))
+                yield
+        finally:
+            cancel_kwargs = (
+                {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {}
+            )
+            for task in running_tasks:
+                task.cancel(**cancel_kwargs)
+
+    def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
+        """Register a task to run during the lifespan of the app.
+
+        Args:
+            task: The task to register.
+            task_kwargs: The kwargs of the task.
+        """
+        if task_kwargs:
+            task = functools.partial(task, **task_kwargs)  # type: ignore
+        self.lifespan_tasks.add(task)  # type: ignore

+ 93 - 0
reflex/app_mixins/middleware.py

@@ -0,0 +1,93 @@
+"""Middleware Mixin that allow to add middleware to the app."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import List
+
+from reflex.event import Event
+from reflex.middleware import HydrateMiddleware, Middleware
+from reflex.state import BaseState, StateUpdate
+
+from .mixin import AppMixin
+
+
+class MiddlewareMixin(AppMixin):
+    """Middleware Mixin that allow to add middleware to the app."""
+
+    # Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
+    middleware: List[Middleware] = []
+
+    def _init_mixin(self):
+        self.middleware.append(HydrateMiddleware())
+
+    def add_middleware(self, middleware: Middleware, index: int | None = None):
+        """Add middleware to the app.
+
+        Args:
+            middleware: The middleware to add.
+            index: The index to add the middleware at.
+        """
+        if index is None:
+            self.middleware.append(middleware)
+        else:
+            self.middleware.insert(index, middleware)
+
+    async def _preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
+        """Preprocess the event.
+
+        This is where middleware can modify the event before it is processed.
+        Each middleware is called in the order it was added to the app.
+
+        If a middleware returns an update, the event is not processed and the
+        update is returned.
+
+        Args:
+            state: The state to preprocess.
+            event: The event to preprocess.
+
+        Returns:
+            An optional state to return.
+        """
+        for middleware in self.middleware:
+            if asyncio.iscoroutinefunction(middleware.preprocess):
+                out = await middleware.preprocess(app=self, state=state, event=event)  # type: ignore
+            else:
+                out = middleware.preprocess(app=self, state=state, event=event)  # type: ignore
+            if out is not None:
+                return out  # type: ignore
+
+    async def _postprocess(
+        self, state: BaseState, event: Event, update: StateUpdate
+    ) -> StateUpdate:
+        """Postprocess the event.
+
+        This is where middleware can modify the delta after it is processed.
+        Each middleware is called in the order it was added to the app.
+
+        Args:
+            state: The state to postprocess.
+            event: The event to postprocess.
+            update: The current state update.
+
+        Returns:
+            The state update to return.
+        """
+        for middleware in self.middleware:
+            if asyncio.iscoroutinefunction(middleware.postprocess):
+                out = await middleware.postprocess(
+                    app=self,  # type: ignore
+                    state=state,
+                    event=event,
+                    update=update,
+                )
+            else:
+                out = middleware.postprocess(
+                    app=self,  # type: ignore
+                    state=state,
+                    event=event,
+                    update=update,
+                )
+            if out is not None:
+                return out  # type: ignore
+        return update

+ 14 - 0
reflex/app_mixins/mixin.py

@@ -0,0 +1,14 @@
+"""Default mixin for all app mixins."""
+
+from reflex.base import Base
+
+
+class AppMixin(Base):
+    """Define the base class for all app mixins."""
+
+    def _init_mixin(self):
+        """Initialize the mixin.
+
+        Any App mixin can override this method to perform any initialization.
+        """
+        ...