123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- import contextvars
- import json
- import uuid
- from collections.abc import MutableMapping
- from pathlib import Path
- from typing import Any, Dict, Iterator, Optional, Union
- import aiofiles
- from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
- from starlette.requests import Request
- from starlette.responses import Response
- from . import background_tasks, globals, observables # pylint: disable=redefined-builtin
# The Starlette request being handled in the current context.
# RequestTrackingMiddleware.dispatch sets it; it stays None where nothing has set it.
request_contextvar: contextvars.ContextVar[Optional[Request]] = contextvars.ContextVar(
    'request_var',
    default=None,
)
class ReadOnlyDict(MutableMapping):
    """Mapping wrapper that exposes a dict for reading but refuses every write.

    The wrapped dict is referenced, not copied, so later changes made to it directly stay visible
    through the wrapper. Setting or deleting a key raises ``TypeError`` carrying ``write_error_message``.
    The mutating mixin methods inherited from ``MutableMapping`` (``update``, ``pop``, ...) are built on
    those same two hooks, so they are rejected as well.
    """

    def __init__(self, data: Dict[Any, Any], write_error_message: str = 'Read-only dict') -> None:
        self._data: Dict[Any, Any] = data
        self._write_error_message: str = write_error_message

    def _refuse_write(self) -> None:
        """Raise the configured ``TypeError``; shared by every mutation hook."""
        raise TypeError(self._write_error_message)

    # --- read access: delegate straight to the wrapped dict ---

    def __getitem__(self, item: Any) -> Any:
        return self._data[item]

    def __iter__(self) -> Iterator:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    # --- write access: always rejected ---

    def __setitem__(self, key: Any, value: Any) -> None:
        self._refuse_write()

    def __delitem__(self, key: Any) -> None:
        self._refuse_write()
class PersistentDict(observables.ObservableDict):
    """Observable dict whose contents are mirrored to a JSON file.

    On construction the file is loaded when it exists; otherwise the dict starts empty.
    Every change invokes :meth:`backup` (registered as ``on_change``), which schedules an
    asynchronous rewrite of the file.
    """

    def __init__(self, filepath: Path, encoding: Optional[str] = None) -> None:
        """Load ``filepath`` (if present) and start observing changes.

        :param filepath: JSON file used to load and persist the data
        :param encoding: text encoding for reading and writing the file
            (``None`` keeps the platform default, as before)
        """
        # set attributes before super().__init__ so they exist even if the base class fires on_change early
        self.filepath = filepath
        self.encoding = encoding
        data = json.loads(filepath.read_text(encoding=encoding)) if filepath.exists() else {}
        super().__init__(data, on_change=self.backup)

    def backup(self) -> None:
        """Schedule an asynchronous write of the current contents to ``self.filepath``.

        While no file exists yet and the dict is empty, nothing is written, so unused storage leaves no file behind.
        If ``globals.loop`` is set, the write is handed to ``background_tasks.create_lazy`` under the file stem as
        its name; otherwise it is registered via ``globals.app.on_startup``.
        """
        if not self.filepath.exists():
            if not self:
                return
            # parents=True: the storage directory may sit below folders that do not exist yet
            self.filepath.parent.mkdir(parents=True, exist_ok=True)

        async def write_file() -> None:
            # serialize inside the coroutine, so the contents at the time the write runs are persisted
            async with aiofiles.open(self.filepath, 'w', encoding=self.encoding) as f:
                await f.write(json.dumps(self))

        if globals.loop:
            background_tasks.create_lazy(write_file(), name=self.filepath.stem)
        else:
            globals.app.on_startup(write_file())
class RequestTrackingMiddleware(BaseHTTPMiddleware):
    """Middleware that exposes the current request to the storage layer.

    For every request it:
    - publishes the request in ``request_contextvar``,
    - gives the session an ``'id'`` entry if it has none yet,
    - sets ``request.state.responded`` to False while the downstream handler runs and to True once it has returned.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        request_contextvar.set(request)

        session = request.session
        if 'id' not in session:
            # session carries no identifier yet: assign a random one
            session['id'] = str(uuid.uuid4())

        state = request.state
        state.responded = False
        downstream_response = await call_next(request)
        state.responded = True
        return downstream_response
class Storage:
    """Access point for the three storage scopes.

    - ``browser``: the request's session (kept in the browser's session cookie)
    - ``user``: a per-session ``PersistentDict`` stored as ``storage_user_<session id>.json``
    - ``general``: one ``PersistentDict`` shared by everybody, stored as ``storage_general.json``
    """

    def __init__(self) -> None:
        self._general = PersistentDict(globals.storage_path / 'storage_general.json')
        self._users: Dict[str, PersistentDict] = {}

    @property
    def browser(self) -> Union[ReadOnlyDict, Dict]:
        """Small storage that is saved directly within the user's browser (session cookie).

        The data is shared between all browser tabs and can only be modified until the response to the initial
        request has been built; afterwards a read-only view is returned.
        Therefore it is normally better to use `app.storage.user` instead,
        which can be modified anytime, reduces overall payload, improves security and has larger storage capacity.
        """
        request: Optional[Request] = request_contextvar.get()
        if request is None:
            used_outside_page_builder = globals.get_client() == globals.index_client
            if used_outside_page_builder:
                raise RuntimeError('app.storage.browser can only be used with page builder functions '
                                   '(https://nicegui.io/documentation/page)')
            raise RuntimeError('app.storage.browser needs a storage_secret passed in ui.run()')

        if not request.state.responded:
            return request.session

        # the response has already been produced, so changes could no longer reach the browser
        return ReadOnlyDict(
            request.session,
            'the response to the browser has already been built, so modifications cannot be sent back anymore'
        )

    @property
    def user(self) -> Dict:
        """Individual user storage that is persisted on the server (where NiceGUI is executed).

        The data lives in a JSON file on the server, one file per session ID, and is therefore
        shared between all browser tabs that carry the same session cookie.
        """
        request: Optional[Request] = request_contextvar.get()
        if request is None:
            used_outside_page_builder = globals.get_client() == globals.index_client
            if used_outside_page_builder:
                raise RuntimeError('app.storage.user can only be used with page builder functions '
                                   '(https://nicegui.io/documentation/page)')
            raise RuntimeError('app.storage.user needs a storage_secret passed in ui.run()')

        session_id = request.session['id']
        user_storage = self._users.get(session_id)
        if user_storage is None:
            # first access for this session in this process: load (or lazily create) its file
            user_storage = PersistentDict(globals.storage_path / f'storage_user_{session_id}.json')
            self._users[session_id] = user_storage
        return user_storage

    @property
    def general(self) -> Dict:
        """General storage shared between all users that is persisted on the server (where NiceGUI is executed)."""
        return self._general

    def clear(self) -> None:
        """Clears all storage."""
        # NOTE(review): if ObservableDict.clear() fires on_change, _general schedules an async backup that may
        # recreate storage_general.json after the unlink below — confirm whether that matters.
        self._general.clear()
        self._users.clear()
        for stored_file in globals.storage_path.glob('storage_*.json'):
            stored_file.unlink()
|