Browse Source

rewrote auth example to use new app.storage code
for this to work context/request needs to be passed into event handler

Rodja Trappe 2 years ago
parent
commit
e8919f9df6
6 changed files with 72 additions and 43 deletions
  1. 11 38
      examples/authentication/main.py
  2. 3 1
      nicegui/element.py
  3. 3 0
      nicegui/event_listener.py
  4. 10 3
      nicegui/storage.py
  5. 2 0
      tests/conftest.py
  6. 43 1
      tests/test_storage.py

+ 11 - 38
examples/authentication/main.py

@@ -1,71 +1,44 @@
 #!/usr/bin/env python3
-'''This is only a very simple authentication example which stores session IDs in memory and does not do any password hashing.
+'''This is a just very simple authentication example.
 
 Please see the `OAuth2 example at FastAPI <https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/>`_  or
-use the great `Authlib package <https://docs.authlib.org/en/v0.13/client/starlette.html#using-fastapi>`_ to implement a real authentication system.
-
+use the great `Authlib package <https://docs.authlib.org/en/v0.13/client/starlette.html#using-fastapi>`_ to implement a classing real authentication system.
 Here we just demonstrate the NiceGUI integration.
 '''
 
-import os
-import uuid
-from typing import Dict
-
-from fastapi import Request
 from fastapi.responses import RedirectResponse
-from starlette.middleware.sessions import SessionMiddleware
 
 from nicegui import app, ui
 
-# put your your own secret key in an environment variable MY_SECRET_KEY
-app.add_middleware(SessionMiddleware, secret_key=os.environ.get('MY_SECRET_KEY', ''))
-
-# in reality users and session_info would be persistent (e.g. database, file, ...) and passwords obviously hashed
+# in reality users passwords would obviously need to be hashed
 users = [('user1', 'pass1'), ('user2', 'pass2')]
-session_info: Dict[str, Dict] = {}
-
-
-def is_authenticated(request: Request) -> bool:
-    return session_info.get(request.session.get('id'), {}).get('authenticated', False)
 
 
 @ui.page('/')
-def main_page(request: Request) -> None:
-    if not is_authenticated(request):
+def main_page() -> None:
+    if not app.storage.individual.get('authenticated', False):
         return RedirectResponse('/login')
-    session = session_info[request.session['id']]
     with ui.column().classes('absolute-center items-center'):
-        ui.label(f'Hello {session["username"]}!').classes('text-2xl')
-        # NOTE we navigate to a new page here to be able to modify the session cookie (it is only editable while a request is en-route)
-        # see https://github.com/zauberzeug/nicegui/issues/527 for more details
-        ui.button('', on_click=lambda: ui.open('/logout')).props('outline round icon=logout')
+        ui.label(f'Hello {app.storage.individual["username"]}!').classes('text-2xl')
+        ui.button('', on_click=lambda: (app.storage.individual.clear(), ui.open('/login'))) \
+            .props('outline round icon=logout')
 
 
 @ui.page('/login')
-def login(request: Request) -> None:
+def login() -> None:
     def try_login() -> None:  # local function to avoid passing username and password as arguments
         if (username.value, password.value) in users:
-            session_info[request.session['id']] = {'username': username.value, 'authenticated': True}
+            app.storage.individual.update({'username': username.value, 'authenticated': True})
             ui.open('/')
         else:
             ui.notify('Wrong username or password', color='negative')
 
-    if is_authenticated(request):
+    if app.storage.individual.get('authenticated', False):
         return RedirectResponse('/')
-    request.session['id'] = str(uuid.uuid4())  # NOTE this stores a new session ID in the cookie of the client
     with ui.card().classes('absolute-center'):
         username = ui.input('Username').on('keydown.enter', try_login)
         password = ui.input('Password').props('type=password').on('keydown.enter', try_login)
         ui.button('Log in', on_click=try_login)
 
 
-@ui.page('/logout')
-def logout(request: Request) -> None:
-    if is_authenticated(request):
-        session_info.pop(request.session['id'])
-        request.session['id'] = None
-        return RedirectResponse('/login')
-    return RedirectResponse('/')
-
-
 ui.run()

+ 3 - 1
nicegui/element.py

@@ -9,7 +9,7 @@ from typing_extensions import Self
 
 from nicegui import json
 
-from . import binding, events, globals, outbox
+from . import binding, events, globals, outbox, storage
 from .elements.mixins.visibility import Visibility
 from .event_listener import EventListener
 from .slot import Slot
@@ -230,6 +230,7 @@ class Element(Visibility):
                 throttle=throttle,
                 leading_events=leading_events,
                 trailing_events=trailing_events,
+                request=storage.request_contextvar.get()
             )
             self._event_listeners[listener.id] = listener
             self.update()
@@ -237,6 +238,7 @@ class Element(Visibility):
 
     def _handle_event(self, msg: Dict) -> None:
         listener = self._event_listeners[msg['listener_id']]
+        storage.request_contextvar.set(listener.request)
         events.handle_event(listener.handler, msg, sender=self)
 
     def update(self) -> None:

+ 3 - 0
nicegui/event_listener.py

@@ -2,6 +2,8 @@ import uuid
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List
 
+from fastapi import Request
+
 from .helpers import KWONLY_SLOTS
 
 
@@ -15,6 +17,7 @@ class EventListener:
     throttle: float
     leading_events: bool
     trailing_events: bool
+    request: Request
 
     def __post_init__(self) -> None:
         self.id = str(uuid.uuid4())

+ 10 - 3
nicegui/storage.py

@@ -37,6 +37,7 @@ class PersistentDict(dict):
         self.lock = threading.Lock()
         self.load()
         self.update(*arg, **kw)
+        self.modified = bool(arg or kw)
 
     def load(self):
         with self.lock:
@@ -50,26 +51,32 @@ class PersistentDict(dict):
     def __setitem__(self, key, value):
         with self.lock:
             super().__setitem__(key, value)
+            self.modified = True
 
     def __delitem__(self, key):
         with self.lock:
             super().__delitem__(key)
+            self.modified = True
+
+    def clear(self):
+        with self.lock:
+            super().clear()
+            self.modified = True
 
     async def backup(self):
         data = dict(self)
-        if data:
+        if self.modified:
             async with aiofiles.open(self.filename, 'w') as f:
                 await f.write(json.dumps(data))
 
 
 class RequestTrackingMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
+        request_contextvar.set(request)
         if 'id' not in request.session:
             request.session['id'] = str(uuid.uuid4())
         request.state.responded = False
-        token = request_contextvar.set(request)
         response = await call_next(request)
-        request_contextvar.reset(token)
         request.state.responded = True
         return response
 

+ 2 - 0
tests/conftest.py

@@ -45,6 +45,8 @@ async def reset_globals() -> Generator[None, None, None]:
     globals.app.storage.general.clear()
     globals.app.storage._individuals.clear()
     await globals.app.storage.backup()
+    assert globals.app.storage._individuals.filename.read_text() == '{}'
+    assert globals.app.storage.general.filename.read_text() == '{}'
     globals.index_client = Client(page('/'), shared=True).__enter__()
     globals.app.get('/')(globals.index_client.build_response)
 

+ 43 - 1
tests/test_storage.py

@@ -1,6 +1,6 @@
 import asyncio
 
-from nicegui import Client, app, ui
+from nicegui import Client, app, background_tasks, ui
 
 from .screen import Screen
 
@@ -70,6 +70,48 @@ def test_individual_storage_modifications(screen: Screen):
     screen.should_contain('3')
 
 
+async def test_access_individual_storage_on_interaction(screen: Screen):
+    @ui.page('/')
+    async def page():
+        if 'test_switch' not in app.storage.individual:
+            app.storage.individual['test_switch'] = False
+        ui.switch('switch').bind_value(app.storage.individual, 'test_switch')
+
+    screen.open('/')
+    screen.click('switch')
+    screen.wait(1)
+    await app.storage.backup()
+    assert '{"test_switch": true}' in app.storage._individuals.filename.read_text()
+
+
+def test_access_individual_storage_from_button_click_handler(screen: Screen):
+    @ui.page('/')
+    async def page():
+        async def inner():
+            app.storage.individual['inner_function'] = 'works'
+            await app.storage.backup()
+
+        ui.button('test', on_click=inner)
+
+    screen.open('/')
+    screen.click('test')
+    screen.wait(1)
+    assert '{"inner_function": "works"}' in app.storage._individuals.filename.read_text()
+
+
+async def test_access_individual_storage_from_background_task(screen: Screen):
+    @ui.page('/')
+    def page():
+        async def subtask():
+            await asyncio.sleep(0.1)
+            app.storage.individual['subtask'] = 'works'
+            await app.storage.backup()
+        background_tasks.create(subtask())
+
+    screen.open('/')
+    assert '{"subtask": "works"}' in app.storage._individuals.filename.read_text()
+
+
 def test_individual_and_general_storage_is_persisted(screen: Screen):
     @ui.page('/')
     def page():