Browse Source

set storage_secret in helpers.py

awfulness 1 year ago
parent
commit
bc1d174b1a
3 changed files with 17 additions and 14 deletions
  1. 14 0
      nicegui/helpers.py
  2. 1 9
      nicegui/run.py
  3. 2 5
      nicegui/run_with.py

+ 14 - 0
nicegui/helpers.py

@@ -9,6 +9,10 @@ import webbrowser
 from contextlib import nullcontext
 from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Union
 
+from starlette.middleware import Middleware
+from starlette.middleware.sessions import SessionMiddleware
+
+from nicegui.storage import RequestTrackingMiddleware
 from . import background_tasks, globals
 
 if TYPE_CHECKING:
@@ -84,3 +88,13 @@ def schedule_browser(host: str, port: int) -> Tuple[threading.Thread, threading.
     thread = threading.Thread(target=in_thread, args=(host, port), daemon=True)
     thread.start()
     return thread, cancel
+
+
+def set_storage_secret(storage_secret: Optional[str] = None) -> None:
+    """Set storage_secret for ui.run() and run_with."""
+    if any(m.cls == SessionMiddleware for m in globals.app.user_middleware):
+        # NOTE not using "add_middleware" because it would be the wrong order
+        globals.app.user_middleware.append(Middleware(RequestTrackingMiddleware))
+    elif storage_secret is not None:
+        globals.app.add_middleware(RequestTrackingMiddleware)
+        globals.app.add_middleware(SessionMiddleware, secret_key=storage_secret)

+ 1 - 9
nicegui/run.py

@@ -7,8 +7,6 @@ from typing import Any, List, Optional, Tuple
 
 import __main__
 import uvicorn
-from starlette.middleware import Middleware
-from starlette.middleware.sessions import SessionMiddleware
 from uvicorn.main import STARTUP_FAILURE
 from uvicorn.supervisors import ChangeReload, Multiprocess
 
@@ -16,7 +14,6 @@ from . import globals, helpers
 from . import native as native_module
 from . import native_mode
 from .language import Language
-from .storage import RequestTrackingMiddleware
 
 
 class Server(uvicorn.Server):
@@ -28,12 +25,7 @@ class Server(uvicorn.Server):
         if native_module.method_queue is not None:
             globals.app.native.main_window = native_module.WindowProxy()
 
-        if any(m.cls == SessionMiddleware for m in globals.app.user_middleware):
-            # NOTE not using "add_middleware" because it would be the wrong order
-            globals.app.user_middleware.append(Middleware(RequestTrackingMiddleware))
-        elif self.config.storage_secret is not None:
-            globals.app.add_middleware(RequestTrackingMiddleware)
-            globals.app.add_middleware(SessionMiddleware, secret_key=self.config.storage_secret)
+        helpers.set_storage_secret(self.config.storage_secret)
         super().run(sockets=sockets)
 
 

+ 2 - 5
nicegui/run_with.py

@@ -1,12 +1,11 @@
 from typing import Optional
 
 from fastapi import FastAPI
-from starlette.middleware.sessions import SessionMiddleware
 
 from nicegui import globals
+from nicegui.helpers import set_storage_secret
 from nicegui.language import Language
 from nicegui.nicegui import handle_shutdown, handle_startup
-from nicegui.storage import RequestTrackingMiddleware
 
 
 def run_with(
@@ -31,9 +30,7 @@ def run_with(
     globals.excludes = [e.strip() for e in exclude.split(',')]
     globals.tailwind = True
 
-    app.add_middleware(RequestTrackingMiddleware)
-    app.add_middleware(SessionMiddleware, secret_key=storage_secret)
-
+    set_storage_secret(storage_secret)
     app.on_event('startup')(lambda: handle_startup(with_welcome_message=False))
     app.on_event('shutdown')(lambda: handle_shutdown())