浏览代码

introducing storage.general and storage.individual
which are persisted as json files on the server

Rodja Trappe 2 年之前
父节点
当前提交
60f8c35982
共有 5 个文件被更改,包括 135 次插入19 次删除
  1. 1 2
      .gitignore
  2. 1 0
      nicegui/nicegui.py
  3. 78 3
      nicegui/storage.py
  4. 5 1
      tests/conftest.py
  5. 50 13
      tests/test_storage.py

+ 1 - 2
.gitignore

@@ -5,7 +5,6 @@ dist
 /test.py
 /test.py
 *.pickle
 *.pickle
 tests/screenshots/
 tests/screenshots/
-
-# ignore local virtual environments
 venv
 venv
 .idea
 .idea
+.nicegui/

+ 1 - 0
nicegui/nicegui.py

@@ -73,6 +73,7 @@ def handle_startup(with_welcome_message: bool = True) -> None:
             safe_invoke(t)
             safe_invoke(t)
     background_tasks.create(binding.loop())
     background_tasks.create(binding.loop())
     background_tasks.create(outbox.loop())
     background_tasks.create(outbox.loop())
+    background_tasks.create(app.storage._loop())
     background_tasks.create(prune_clients())
     background_tasks.create(prune_clients())
     background_tasks.create(prune_slot_stacks())
     background_tasks.create(prune_slot_stacks())
     globals.state = globals.State.STARTED
     globals.state = globals.State.STARTED

+ 78 - 3
nicegui/storage.py

@@ -1,6 +1,12 @@
+import asyncio
 import contextvars
 import contextvars
+import json
+import threading
+import uuid
+from pathlib import Path
 from typing import Dict
 from typing import Dict
 
 
+import aiofiles
 from fastapi import Request
 from fastapi import Request
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.base import BaseHTTPMiddleware
 
 
@@ -25,20 +31,63 @@ class ReadOnlyDict:
         raise TypeError(self._write_error_message)
         raise TypeError(self._write_error_message)
 
 
 
 
+class PersistentDict(dict):
+    def __init__(self, filename: Path, *arg, **kw):
+        self.filename = filename
+        self.lock = threading.Lock()
+        self.load()
+        self.update(*arg, **kw)
+
+    def load(self):
+        with self.lock:
+            if self.filename.exists():
+                with open(self.filename, 'r') as f:
+                    try:
+                        self.update(json.load(f))
+                    except json.JSONDecodeError:
+                        pass
+
+    def __setitem__(self, key, value):
+        with self.lock:
+            super().__setitem__(key, value)
+
+    def __delitem__(self, key):
+        with self.lock:
+            super().__delitem__(key)
+
+    async def backup(self):
+        data = dict(self)
+        if data:
+            async with aiofiles.open(self.filename, 'w') as f:
+                await f.write(json.dumps(data))
+
+
 class RequestTrackingMiddleware(BaseHTTPMiddleware):
 class RequestTrackingMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
     async def dispatch(self, request: Request, call_next):
-        token = request_contextvar.set(request)
+        if 'id' not in request.session:
+            request.session['id'] = str(uuid.uuid4())
         request.state.responded = False
         request.state.responded = False
+        token = request_contextvar.set(request)
         response = await call_next(request)
         response = await call_next(request)
-        request.state.responded = True
         request_contextvar.reset(token)
         request_contextvar.reset(token)
+        request.state.responded = True
         return response
         return response
 
 
 
 
 class Storage:
 class Storage:
 
 
+    def __init__(self):
+        self.storage_dir = Path('.nicegui')
+        self.storage_dir.mkdir(exist_ok=True)
+        self._general = PersistentDict(self.storage_dir / 'storage_general.json')
+        self._individuals = PersistentDict(self.storage_dir / 'storage_individuals.json')
+
     @property
     @property
-    def session(self) -> Dict:
+    def browser(self) -> Dict:
+        """Small storage that is saved directly within the user's browser (encrypted cookie).
+
+        The data is shared between all browser tab and can only be modified before the initial request has been submitted.
+        Normally it is better to use `app.storage.individual` instead to reduce payload, improved security and larger storage capacity)."""
         request: Request = request_contextvar.get()
         request: Request = request_contextvar.get()
         if request.state.responded:
         if request.state.responded:
             return ReadOnlyDict(
             return ReadOnlyDict(
@@ -46,3 +95,29 @@ class Storage:
                 'the response to the browser has already been build so modifications can not be send back anymore'
                 'the response to the browser has already been build so modifications can not be send back anymore'
             )
             )
         return request.session
         return request.session
+
+    @property
+    def individual(self) -> Dict:
+        """Individual user storage that is persisted on the server.
+
+        The data is stored in a file on the server.
+        It is shared between all browser tabs by identifying the user via session cookie id.
+        """
+        request: Request = request_contextvar.get()
+        if request.session['id'] not in self._individuals:
+            self._individuals[request.session['id']] = {}
+        return self._individuals[request.session['id']]
+
+    @property
+    def general(self) -> Dict:
+        """General storage shared between all users that is persisted on the server."""
+        return self._general
+
+    async def backup(self):
+        await self._general.backup()
+        await self._individuals.backup()
+
+    async def _loop(self):
+        while True:
+            await self.backup()
+            await asyncio.sleep(10)

+ 5 - 1
tests/conftest.py

@@ -36,11 +36,15 @@ def selenium(selenium: webdriver.Chrome) -> webdriver.Chrome:
 
 
 
 
 @pytest.fixture(autouse=True)
 @pytest.fixture(autouse=True)
-def reset_globals() -> Generator[None, None, None]:
+async def reset_globals() -> Generator[None, None, None]:
     for path in {'/'}.union(globals.page_routes.values()):
     for path in {'/'}.union(globals.page_routes.values()):
         globals.app.remove_route(path)
         globals.app.remove_route(path)
     globals.app.middleware_stack = None
     globals.app.middleware_stack = None
     importlib.reload(globals)
     importlib.reload(globals)
+    # importlib.reload(nicegui)
+    globals.app.storage.general.clear()
+    globals.app.storage._individuals.clear()
+    await globals.app.storage.backup()
     globals.index_client = Client(page('/'), shared=True).__enter__()
     globals.index_client = Client(page('/'), shared=True).__enter__()
     globals.app.get('/')(globals.index_client.build_response)
     globals.app.get('/')(globals.index_client.build_response)
 
 

+ 50 - 13
tests/test_storage.py

@@ -6,15 +6,15 @@ from nicegui import Client, app, ui
 from .screen import Screen
 from .screen import Screen
 
 
 
 
-def test_session_data_is_stored_in_the_browser(screen: Screen):
+def test_browser_data_is_stored_in_the_browser(screen: Screen):
     @ui.page('/')
     @ui.page('/')
     def page():
     def page():
-        app.storage.session['count'] = app.storage.session.get('count', 0) + 1
-        ui.label(app.storage.session['count'] or 'no session')
+        app.storage.browser['count'] = app.storage.browser.get('count', 0) + 1
+        ui.label(app.storage.browser['count'] or 'no session')
 
 
-    @app.get('/session')
-    def session():
-        return 'count = ' + str(app.storage.session['count'])
+    @app.get('/count')
+    def count():
+        return 'count = ' + str(app.storage.browser['count'])
 
 
     screen.open('/')
     screen.open('/')
     screen.should_contain('1')
     screen.should_contain('1')
@@ -22,16 +22,16 @@ def test_session_data_is_stored_in_the_browser(screen: Screen):
     screen.should_contain('2')
     screen.should_contain('2')
     screen.open('/')
     screen.open('/')
     screen.should_contain('3')
     screen.should_contain('3')
-    screen.open('/session')
-    screen.should_contain('count = 3')
+    screen.open('/count')
+    screen.should_contain('count = 3')  # also works with FastAPI endpoints
 
 
 
 
-def test_session_storage_supports_asyncio(screen: Screen):
+def test_browser_storage_supports_asyncio(screen: Screen):
     @ui.page('/')
     @ui.page('/')
     async def page():
     async def page():
-        app.storage.session['count'] = app.storage.session.get('count', 0) + 1
+        app.storage.browser['count'] = app.storage.browser.get('count', 0) + 1
         await asyncio.sleep(0.5)
         await asyncio.sleep(0.5)
-        ui.label(app.storage.session['count'] or 'no session')
+        ui.label(app.storage.browser['count'] or 'no session')
 
 
     screen.open('/')
     screen.open('/')
     screen.switch_to(1)
     screen.switch_to(1)
@@ -42,14 +42,51 @@ def test_session_storage_supports_asyncio(screen: Screen):
     screen.should_contain('3')
     screen.should_contain('3')
 
 
 
 
-def test_session_modifications_after_page_load(screen: Screen):
+def test_browser_storage_modifications_after_page_load_are_forbidden(screen: Screen):
     @ui.page('/')
     @ui.page('/')
     async def page(client: Client):
     async def page(client: Client):
         await client.connected()
         await client.connected()
         try:
         try:
-            app.storage.session['test'] = 'data'
+            app.storage.browser['test'] = 'data'
         except TypeError as e:
         except TypeError as e:
             ui.label(str(e))
             ui.label(str(e))
 
 
     screen.open('/')
     screen.open('/')
     screen.should_contain('response to the browser has already been build')
     screen.should_contain('response to the browser has already been build')
+
+
+def test_individual_storage_modifications(screen: Screen):
+    @ui.page('/')
+    async def page(client: Client, delayed: bool = False):
+        if delayed:
+            await client.connected()
+        app.storage.individual['count'] = app.storage.individual.get('count', 0) + 1
+        ui.label(app.storage.individual['count'] or 'no session')
+
+    screen.open('/')
+    screen.should_contain('1')
+    screen.open('/?delayed=True')
+    screen.should_contain('2')
+    screen.open('/')
+    screen.should_contain('3')
+
+
+def test_individual_and_general_storage_is_persisted(screen: Screen):
+    @ui.page('/')
+    def page():
+        app.storage.individual['count'] = app.storage.individual.get('count', 0) + 1
+        app.storage.general['count'] = app.storage.general.get('count', 0) + 1
+        ui.label(f'individual: {app.storage.individual["count"]}')
+        ui.label(f'general: {app.storage.general["count"]}')
+        ui.button('backup', on_click=app.storage.backup)
+
+    screen.open('/')
+    screen.open('/')
+    screen.open('/')
+    screen.should_contain('individual: 3')
+    screen.should_contain('general: 3')
+    screen.click('backup')
+    screen.selenium.delete_all_cookies()
+    screen.open('/')
+    screen.should_contain('individual: 1')
+    screen.should_contain('general: 4')