1
0
Эх сурвалжийг харах

Merge pull request #92 from zauberzeug/auto-context

Auto context
Rodja Trappe 2 жил өмнө
parent
commit
b00343ed4d

+ 5 - 4
nicegui/elements/group.py

@@ -12,13 +12,14 @@ from .element import Element
 class Group(Element):
 
     def __enter__(self):
-        globals.view_stack.append(self.view)
+        self._child_count_on_enter = len(self.view)
+        globals.get_view_stack().append(self.view)
         return self
 
     def __exit__(self, *_):
-        globals.view_stack.pop()
-        if len(globals.view_stack) <= 1:
-            self.update()  # NOTE: update when we are back on top of the stack (only the first page is in view stack)
+        globals.get_view_stack().pop()
+        if self._child_count_on_enter != len(self.view):
+            self.update()
 
     def tight(self) -> Group:
         return self.classes(replace='').style(replace='')

+ 2 - 2
nicegui/elements/scene.py

@@ -108,14 +108,14 @@ class Scene(Element):
         super().__init__(SceneView(width=width, height=height, on_click=on_click))
 
     def __enter__(self):
-        globals.view_stack.append(self.view)
+        globals.get_view_stack().append(self.view)
         scene = self.view.objects.get('scene', SceneObject(self.view, self.page))
         Object3D.stack.clear()
         Object3D.stack.append(scene)
         return self
 
     def __exit__(self, *_):
-        globals.view_stack.pop()
+        globals.get_view_stack().pop()
 
     def move_camera(self,
                     x: Optional[float] = None,

+ 7 - 3
nicegui/events.py

@@ -230,12 +230,16 @@ def handle_event(handler: Optional[Callable], arguments: EventArguments) -> Opti
         if handler is None:
             return False
         no_arguments = not signature(handler).parameters
-        result = handler() if no_arguments else handler(arguments)
+        with globals.within_view(arguments.sender.parent_view):
+            result = handler() if no_arguments else handler(arguments)
         if is_coroutine(handler):
+            async def wait_for_result():
+                with globals.within_view(arguments.sender.parent_view):
+                    await result
             if globals.loop and globals.loop.is_running():
-                create_task(result, name=str(handler))
+                create_task(wait_for_result(), name=str(handler))
             else:
-                on_startup(None, result)
+                on_startup(None, wait_for_result())
         return False
     except Exception:
         traceback.print_exc()

+ 32 - 2
nicegui/globals.py

@@ -2,7 +2,8 @@ from __future__ import annotations
 
 import asyncio
 import logging
-from typing import Awaitable, Callable, Dict, List, Optional, Union
+from contextlib import contextmanager
+from typing import Awaitable, Callable, Dict, Generator, List, Optional, Union
 
 import justpy as jp
 from starlette.applications import Starlette
@@ -10,13 +11,14 @@ from uvicorn import Server
 
 from .config import Config
 from .page_builder import PageBuilder
+from .task_logger import create_task
 
 app: Starlette
 config: Optional[Config] = None
 server: Optional[Server] = None
 loop: Optional[asyncio.AbstractEventLoop] = None
 page_builders: Dict[str, 'PageBuilder'] = {}
-view_stack: List[jp.HTMLBaseComponent] = []
+view_stacks: Dict[List[jp.HTMLBaseComponent]] = {}
 tasks: List[asyncio.tasks.Task] = []
 log: logging.Logger = logging.getLogger('nicegui')
 connect_handlers: List[Union[Callable, Awaitable]] = []
@@ -30,3 +32,31 @@ def find_route(function: Callable) -> str:
     if not routes:
         raise ValueError(f'Invalid page function {function}')
     return routes[0]
+
+
+def get_task_id() -> int:
+    return id(asyncio.current_task()) if loop and loop.is_running() else 0
+
+
+def get_view_stack() -> List[jp.HTMLBaseComponent]:
+    task_id = get_task_id()
+    if task_id not in view_stacks:
+        view_stacks[task_id] = []
+    return view_stacks[task_id]
+
+
+def prune_view_stack() -> None:
+    task_id = get_task_id()
+    if not view_stacks[task_id]:
+        del view_stacks[task_id]
+
+
+@contextmanager
+def within_view(view: jp.HTMLBaseComponent) -> Generator[None, None, None]:
+    child_count = len(view)
+    get_view_stack().append(view)
+    yield
+    get_view_stack().pop()
+    prune_view_stack()
+    if len(view) != child_count:
+        create_task(view.update())

+ 10 - 9
nicegui/page.py

@@ -168,9 +168,8 @@ def page(self,
                 on_disconnect=on_disconnect,
                 shared=shared,
             )
-            globals.view_stack.append(page.view)
-            await func() if is_coroutine(func) else func()
-            globals.view_stack.pop()
+            with globals.within_view(page.view):
+                await func() if is_coroutine(func) else func()
             return page
         builder = PageBuilder(decorated, shared)
         if globals.server:
@@ -181,13 +180,14 @@ def page(self,
 
 
 def find_parent_view() -> jp.HTMLBaseComponent:
-    if not globals.view_stack:
+    view_stack = globals.get_view_stack()
+    if not view_stack:
         if globals.loop and globals.loop.is_running():
             raise RuntimeError('cannot find parent view, view stack is empty')
         page = Page(shared=True)
-        globals.view_stack.append(page.view)
+        view_stack.append(page.view)
         jp.Route('/', page._route_function)
-    return globals.view_stack[-1]
+    return view_stack[-1]
 
 
 def error404() -> jp.QuasarPage:
@@ -204,14 +204,15 @@ def error404() -> jp.QuasarPage:
 
 
 def init_auto_index_page() -> None:
-    if not globals.view_stack:
+    view_stack = globals.view_stacks.get(0)
+    if not view_stack:
         return  # there is no auto-index page on the view stack
-    page: Page = globals.view_stack.pop().pages[0]
+    page: Page = view_stack.pop().pages[0]
     page.title = globals.config.title
     page.favicon = globals.config.favicon
     page.dark = globals.config.dark
     page.view.classes = globals.config.main_page_classes
-    assert len(globals.view_stack) == 0
+    assert len(view_stack) == 0
 
 
 def create_page_routes() -> None:

+ 6 - 4
nicegui/timer.py

@@ -7,6 +7,7 @@ from typing import Callable, List
 from . import globals
 from .binding import BindableProperty
 from .helpers import is_coroutine
+from .page import find_parent_view
 from .task_logger import create_task
 
 NamedCoroutine = namedtuple('NamedCoroutine', ['name', 'coro'])
@@ -32,13 +33,14 @@ class Timer:
 
         self.active = active
         self.interval = interval
+        self.parent_view = find_parent_view()
 
         async def do_callback():
             try:
-                if is_coroutine(callback):
-                    return await callback()
-                else:
-                    return callback()
+                with globals.within_view(self.parent_view):
+                    result = callback()
+                    if is_coroutine(callback):
+                        await result
             except Exception:
                 traceback.print_exc()
 

+ 46 - 0
tests/test_auto_context.py

@@ -0,0 +1,46 @@
+import asyncio
+
+from nicegui import ui
+
+from .user import User
+
+
+def test_adding_element_to_index_page(user: User):
+    ui.button('add label', on_click=lambda: ui.label('added'))
+
+    user.open('/')
+    user.click('add label')
+    user.should_see('added')
+
+
+def test_adding_element_to_private_page(user: User):
+    @ui.page('/')
+    def page():
+        ui.button('add label', on_click=lambda: ui.label('added'))
+
+    user.open('/')
+    user.click('add label')
+    user.should_see('added')
+
+
+def test_adding_elements_with_async_await(user: User):
+    async def add_a():
+        await asyncio.sleep(0.1)
+        ui.label('A')
+
+    async def add_b():
+        await asyncio.sleep(0.1)
+        ui.label('B')
+
+    with ui.card():
+        ui.timer(1.0, add_a, once=True)
+    with ui.card():
+        ui.timer(1.1, add_b, once=True)
+
+    user.open('/')
+    assert '''
+card
+  A
+card
+  B
+''' in user.page(), f'{user.page()} should show cards with "A" and "B"'

+ 31 - 2
tests/test_pages.py

@@ -5,13 +5,42 @@ from nicegui import ui
 from .user import User
 
 
-def test_title(user: User):
+def test_page(user: User):
+    @ui.page('/')
+    def page():
+        ui.label('Hello, world!')
+
+    user.open('/')
+    user.should_see('NiceGUI')
+    user.should_see('Hello, world!')
+
+
+def test_shared_page(user: User):
+    @ui.page('/', shared=True)
+    def page():
+        ui.label('Hello, world!')
+
+    user.open('/')
+    user.should_see('NiceGUI')
+    user.should_see('Hello, world!')
+
+
+def test_auto_index_page(user: User):
+    ui.label('Hello, world!')
+
+    user.open('/')
+    user.should_see('NiceGUI')
+    user.should_see('Hello, world!')
+
+
+def test_custom_title(user: User):
     @ui.page('/', title='My Custom Title')
     def page():
-        ui.label('some content')
+        ui.label('Hello, world!')
 
     user.open('/')
     user.should_see('My Custom Title')
+    user.should_see('Hello, world!')
 
 
 def test_route_with_custom_path(user: User):

+ 17 - 0
tests/test_user.py

@@ -11,10 +11,25 @@ def test_rendering_page(user: User):
         ui.label('1')
         ui.label('2')
         ui.label('3')
+    with ui.card():
+        ui.label('some text')
 
     user.open('/')
     assert user.page() == '''Title: NiceGUI
 
+test label
+row
+  test input: some placeholder
+column
+  1
+  2
+  3
+card
+  some text
+'''
+
+    assert user.page(with_extras=True) == '''Title: NiceGUI
+
 test label
 row [class: items-start positive]
   test input: some placeholder [class: no-wrap items-start standard labeled]
@@ -22,4 +37,6 @@ column [class: items-start]
   1
   2
   3
+card [class: items-start q-pa-md]
+  some text
 '''

+ 9 - 8
tests/user.py

@@ -7,7 +7,7 @@ from selenium.common.exceptions import NoSuchElementException
 from selenium.webdriver.remote.webelement import WebElement
 
 PORT = 3392
-IGNORED_CLASSES = ['row', 'column', 'q-field', 'q-field__label', 'q-input']
+IGNORED_CLASSES = ['row', 'column', 'q-card', 'q-field', 'q-field__label', 'q-input']
 
 
 class User():
@@ -54,10 +54,11 @@ class User():
         except NoSuchElementException:
             raise AssertionError(f'Could not find "{text}" on:\n{self.page()}')
 
-    def page(self) -> str:
-        return f'Title: {self.selenium.title}\n\n' + self.content(self.selenium.find_element_by_tag_name('body'))
+    def page(self, with_extras: bool = False) -> str:
+        return f'Title: {self.selenium.title}\n\n' + \
+            self.content(self.selenium.find_element_by_tag_name('body'), with_extras=with_extras)
 
-    def content(self, element: WebElement, indent: str = '') -> str:
+    def content(self, element: WebElement, indent: str = '', with_extras: bool = False) -> str:
         content = ''
         classes: list[str] = []
         for child in element.find_elements_by_xpath('./*'):
@@ -70,8 +71,8 @@ class User():
                 content += f'{indent}{child.text}'
             classes = child.get_attribute('class').strip().split()
             if classes:
-                if classes[0] in ['row', 'column']:
-                    content += classes[0]
+                if classes[0] in ['row', 'column', 'q-card']:
+                    content += classes[0].removeprefix('q-')
                     is_element = True
                     is_group = True
                 if classes[0] == 'q-field':
@@ -87,12 +88,12 @@ class User():
                 [classes.remove(c) for c in IGNORED_CLASSES if c in classes]
                 for i, c in enumerate(classes):
                     classes[i] = c.removeprefix('q-field--')
-                if is_element:
+                if is_element and with_extras:
                     content += f' [class: {" ".join(classes)}]'
             if is_element:
                 content += '\n'
             if render_children:
-                content += self.content(child, indent + ('  ' if is_group else ''))
+                content += self.content(child, indent + ('  ' if is_group else ''), with_extras)
         return content
 
     def get_tags(self, name: str) -> list[WebElement]: