Browse Source

first proof of concept for global session access

Rodja Trappe 2 years ago
parent
commit
1360ef5248
4 changed files with 34 additions and 1 deletions
  1. 6 0
      nicegui/globals.py
  2. 2 0
      nicegui/page.py
  3. 9 1
      nicegui/run.py
  4. 17 0
      tests/test_session.py

+ 6 - 0
nicegui/globals.py

@@ -5,6 +5,7 @@ from contextlib import contextmanager
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union
 
+from fastapi import Request
 from socketio import AsyncServer
 from uvicorn import Server
 
@@ -45,6 +46,7 @@ socket_io_js_extra_headers: Dict = {}
 
 _socket_id: Optional[str] = None
 slot_stacks: Dict[int, List['Slot']] = {}
+requests: Dict[str, Request] = {}
 clients: Dict[str, 'Client'] = {}
 index_client: 'Client'
 
@@ -85,6 +87,10 @@ def get_client() -> 'Client':
     return get_slot().parent.client
 
 
+def get_request() -> Request:
+    return requests[_socket_id]
+
+
 @contextmanager
 def socket_id(id: str) -> None:
     global _socket_id

+ 2 - 0
nicegui/page.py

@@ -68,6 +68,7 @@ class page:
 
         async def decorated(*dec_args, **dec_kwargs) -> Response:
             request = dec_kwargs['request']
+            globals.requests[globals._socket_id] = request
             # NOTE cleaning up the keyword args so the signature is consistent with "func" again
             dec_kwargs = {k: v for k, v in dec_kwargs.items() if k in parameters_of_decorated_func}
             with Client(self) as client:
@@ -87,6 +88,7 @@ class page:
                 result = task.result() if task.done() else None
             if isinstance(result, Response):  # NOTE if setup returns a response, we don't need to render the page
                 return result
+            del globals.requests[globals._socket_id]
             return client.build_response(request)
 
         parameters = [p for p in inspect.signature(func).parameters.values() if p.name != 'client']

+ 9 - 1
nicegui/run.py

@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple
 
 import __main__
 import uvicorn
+from starlette.middleware.sessions import SessionMiddleware
 from uvicorn.main import STARTUP_FAILURE
 from uvicorn.supervisors import ChangeReload, Multiprocess
 
@@ -13,6 +14,13 @@ from . import globals, helpers, native_mode
 from .language import Language
 
 
+class Server(uvicorn.Server):
+
+    def run(self, sockets: List[Any] = None) -> None:
+        globals.app.add_middleware(SessionMiddleware, secret_key='some_random_string')  # TODO real random string
+        super().run(sockets=sockets)
+
+
 def run(*,
         host: Optional[str] = None,
         port: int = 8080,
@@ -115,7 +123,7 @@ def run(*,
         log_level=uvicorn_logging_level,
         **kwargs,
     )
-    globals.server = uvicorn.Server(config=config)
+    globals.server = Server(config=config)
 
     if (reload or config.workers > 1) and not isinstance(config.app, str):
         logging.warning('You must pass the application as an import string to enable "reload" or "workers".')

+ 17 - 0
tests/test_session.py

@@ -0,0 +1,17 @@
+from nicegui import globals, ui
+
+from .screen import Screen
+
+
+def test_session_data_is_stored_in_the_browser(screen: Screen):
+    @ui.page('/')
+    def page():
+        globals.get_request().session['count'] = globals.get_request().session.get('count', 0) + 1
+        ui.label(globals.get_request().session['count'])
+
+    screen.open('/')
+    screen.should_contain('1')
+    screen.open('/')
+    screen.should_contain('2')
+    screen.open('/')
+    screen.should_contain('3')