Browse Source

custom middleware allows session access in fastAPI functions, also moved the session from globals to ui with contextvar

Rodja Trappe 2 năm trước cách đây
mục cha
commit
db91f449cb
6 tập tin đã thay đổi với 41 bổ sung10 xóa
  1. 0 4
      nicegui/globals.py
  2. 0 2
      nicegui/page.py
  3. 2 0
      nicegui/run.py
  4. 16 0
      nicegui/session.py
  5. 3 0
      nicegui/ui.py
  6. 20 4
      tests/test_session.py

+ 0 - 4
nicegui/globals.py

@@ -87,10 +87,6 @@ def get_client() -> 'Client':
     return get_slot().parent.client
 
 
-def get_request() -> Request:
-    return requests[_socket_id]
-
-
 @contextmanager
 def socket_id(id: str) -> None:
     global _socket_id

+ 0 - 2
nicegui/page.py

@@ -68,7 +68,6 @@ class page:
 
         async def decorated(*dec_args, **dec_kwargs) -> Response:
             request = dec_kwargs['request']
-            globals.requests[globals._socket_id] = request
             # NOTE cleaning up the keyword args so the signature is consistent with "func" again
             dec_kwargs = {k: v for k, v in dec_kwargs.items() if k in parameters_of_decorated_func}
             with Client(self) as client:
@@ -88,7 +87,6 @@ class page:
                 result = task.result() if task.done() else None
             if isinstance(result, Response):  # NOTE if setup returns a response, we don't need to render the page
                 return result
-            del globals.requests[globals._socket_id]
             return client.build_response(request)
 
         parameters = [p for p in inspect.signature(func).parameters.values() if p.name != 'client']

+ 2 - 0
nicegui/run.py

@@ -12,11 +12,13 @@ from uvicorn.supervisors import ChangeReload, Multiprocess
 
 from . import globals, helpers, native_mode
 from .language import Language
+from .session import RequestTrackingMiddleware
 
 
 class Server(uvicorn.Server):
 
     def run(self, sockets: List[Any] = None) -> None:
+        globals.app.add_middleware(RequestTrackingMiddleware)
         globals.app.add_middleware(SessionMiddleware, secret_key='some_random_string')  # TODO real random string
         super().run(sockets=sockets)
 

+ 16 - 0
nicegui/session.py

@@ -0,0 +1,16 @@
+from fastapi import Request
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.types import ASGIApp
+
+from . import globals, ui
+
+
+class RequestTrackingMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app: ASGIApp) -> None:
+        super().__init__(app)
+        self.requests = {}
+
+    async def dispatch(self, request: Request, call_next):
+        ui.session.set(request.session)
+        response = await call_next(request)
+        return response

+ 3 - 0
nicegui/ui.py

@@ -1,3 +1,4 @@
+import contextvars
 import os
 
 __all__ = [
@@ -167,6 +168,8 @@ from .page_layout import RightDrawer as right_drawer
 from .run import run
 from .run_with import run_with
 
+session = contextvars.ContextVar('session_var')
+
 if os.environ.get('MATPLOTLIB', 'true').lower() == 'true':
     from .elements.line_plot import LinePlot as line_plot
     from .elements.pyplot import Pyplot as pyplot

+ 20 - 4
tests/test_session.py

@@ -1,13 +1,19 @@
-from nicegui import globals, ui
+import requests
 
-from .screen import Screen
+from nicegui import app, globals, ui
+
+from .screen import PORT, Screen
 
 
 def test_session_data_is_stored_in_the_browser(screen: Screen):
     @ui.page('/')
     def page():
-        globals.get_request().session['count'] = globals.get_request().session.get('count', 0) + 1
-        ui.label(globals.get_request().session['count'])
+        ui.session.get()['count'] = ui.session.get().get('count', 0) + 1
+        ui.label(ui.session.get()['count'] or 'no session')
+
+    @app.get('/session')
+    def session():
+        return 'count = ' + str(ui.session.get()['count'])
 
     screen.open('/')
     screen.should_contain('1')
@@ -15,3 +21,13 @@ def test_session_data_is_stored_in_the_browser(screen: Screen):
     screen.should_contain('2')
     screen.open('/')
     screen.should_contain('3')
+    screen.open('/session')
+    screen.should_contain('count = 3')
+    # assert screen.selenium.g(f'http://localhost:{PORT}/session').json() == 3
+
+    #     ui.input('name').bind_value(request.session, 'key')
+
+    # screen.open('/')
+    # screen.find('input').send_keys('some text')
+    # screen.open('/')
+    # screen.should_contain('some text')