瀏覽代碼

avoid cyclic imports

Falko Schindler 1 年之前
父節點
當前提交
2b00a67d9e
共有 4 個文件被更改,包括 16 次插入17 次删除
  1. 0 13
      nicegui/helpers.py
  2. 2 1
      nicegui/run.py
  3. 2 3
      nicegui/run_with.py
  4. 12 0
      nicegui/storage.py

+ 0 - 13
nicegui/helpers.py

@@ -16,11 +16,8 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Optional,
 
 from fastapi import Request
 from fastapi.responses import StreamingResponse
-from starlette.middleware import Middleware
-from starlette.middleware.sessions import SessionMiddleware
 
 from . import background_tasks, globals  # pylint: disable=redefined-builtin
-from .storage import RequestTrackingMiddleware
 
 if TYPE_CHECKING:
     from .client import Client
@@ -125,16 +122,6 @@ def schedule_browser(host: str, port: int) -> Tuple[threading.Thread, threading.
     return thread, cancel
 
 
-def set_storage_secret(storage_secret: Optional[str] = None) -> None:
-    """Set storage_secret and add request tracking middleware."""
-    if any(m.cls == SessionMiddleware for m in globals.app.user_middleware):
-        # NOTE not using "add_middleware" because it would be the wrong order
-        globals.app.user_middleware.append(Middleware(RequestTrackingMiddleware))
-    elif storage_secret is not None:
-        globals.app.add_middleware(RequestTrackingMiddleware)
-        globals.app.add_middleware(SessionMiddleware, secret_key=storage_secret)
-
-
 def get_streaming_response(file: Path, request: Request) -> StreamingResponse:
     """Get a StreamingResponse for the given file and request."""
     file_size = file.stat().st_size

+ 2 - 1
nicegui/run.py

@@ -12,6 +12,7 @@ from uvicorn.main import STARTUP_FAILURE
 from uvicorn.supervisors import ChangeReload, Multiprocess
 
 from . import native_mode  # pylint: disable=redefined-builtin
+from . import storage  # pylint: disable=redefined-builtin
 from . import globals, helpers  # pylint: disable=redefined-builtin
 from . import native as native_module
 from .air import Air
@@ -36,7 +37,7 @@ class Server(uvicorn.Server):
             native_module.response_queue = self.config.response_queue
             globals.app.native.main_window = native_module.WindowProxy()
 
-        helpers.set_storage_secret(self.config.storage_secret)
+        storage.set_storage_secret(self.config.storage_secret)
         super().run(sockets=sockets)
 
 

+ 2 - 3
nicegui/run_with.py

@@ -3,8 +3,7 @@ from typing import Optional, Union
 
 from fastapi import FastAPI
 
-from nicegui import globals  # pylint: disable=redefined-builtin
-from nicegui.helpers import set_storage_secret
+from nicegui import globals, storage  # pylint: disable=redefined-builtin
 from nicegui.language import Language
 from nicegui.nicegui import handle_shutdown, handle_startup
 
@@ -49,7 +48,7 @@ def run_with(
     globals.tailwind = tailwind
     globals.prod_js = prod_js
 
-    set_storage_secret(storage_secret)
+    storage.set_storage_secret(storage_secret)
     app.on_event('startup')(lambda: handle_startup(with_welcome_message=False))
     app.on_event('shutdown')(handle_shutdown)
 

+ 12 - 0
nicegui/storage.py

@@ -6,7 +6,9 @@ from pathlib import Path
 from typing import Any, Dict, Iterator, Optional, Union
 
 import aiofiles
+from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.middleware.sessions import SessionMiddleware
 from starlette.requests import Request
 from starlette.responses import Response
 
@@ -72,6 +74,16 @@ class RequestTrackingMiddleware(BaseHTTPMiddleware):
         return response
 
 
+def set_storage_secret(storage_secret: Optional[str] = None) -> None:
+    """Set storage_secret and add request tracking middleware."""
+    if any(m.cls == SessionMiddleware for m in globals.app.user_middleware):
+        # NOTE not using "add_middleware" because it would be the wrong order
+        globals.app.user_middleware.append(Middleware(RequestTrackingMiddleware))
+    elif storage_secret is not None:
+        globals.app.add_middleware(RequestTrackingMiddleware)
+        globals.app.add_middleware(SessionMiddleware, secret_key=storage_secret)
+
+
 class Storage:
 
     def __init__(self) -> None: