Bladeren bron

Also handle filtering endpoints created at runtime

David Kincaid 1 jaar geleden
bovenliggende
commit
186f9771b1
4 gewijzigde bestanden met toevoegingen van 38 en 20 verwijderingen
  1. 14 1
      nicegui/globals.py
  2. 6 0
      nicegui/page.py
  3. 10 11
      nicegui/run.py
  4. 8 8
      tests/test_endpoint_docs.py

+ 14 - 1
nicegui/globals.py

@@ -4,7 +4,19 @@ import logging
 from contextlib import contextmanager
 from contextlib import contextmanager
 from enum import Enum
 from enum import Enum
 from pathlib import Path
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterator, List, Optional, Set, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Union,
+)
 
 
 from socketio import AsyncServer
 from socketio import AsyncServer
 from uvicorn import Server
 from uvicorn import Server
@@ -45,6 +57,7 @@ binding_refresh_interval: float
 tailwind: bool
 tailwind: bool
 air: Optional['Air'] = None
 air: Optional['Air'] = None
 socket_io_js_extra_headers: Dict = {}
 socket_io_js_extra_headers: Dict = {}
+endpoint_documentation: Literal['none', 'internal', 'page', 'all'] = 'none'
 
 
 _socket_id: Optional[str] = None
 _socket_id: Optional[str] = None
 slot_stacks: Dict[int, List['Slot']] = {}
 slot_stacks: Dict[int, List['Slot']] = {}

+ 6 - 0
nicegui/page.py

@@ -107,6 +107,12 @@ class page:
             parameters.insert(0, request)
             parameters.insert(0, request)
         decorated.__signature__ = inspect.Signature(parameters)
         decorated.__signature__ = inspect.Signature(parameters)
 
 
+        show_in_docs = False
+        if globals.endpoint_documentation in ['page', 'all']:
+            show_in_docs = True
+
+        self.kwargs['include_in_schema'] = show_in_docs
+
         self.api_router.get(self._path, **self.kwargs)(decorated)
         self.api_router.get(self._path, **self.kwargs)(decorated)
         globals.page_routes[func] = self.path
         globals.page_routes[func] = self.path
         return func
         return func

+ 10 - 11
nicegui/run.py

@@ -11,9 +11,8 @@ import uvicorn
 from uvicorn.main import STARTUP_FAILURE
 from uvicorn.main import STARTUP_FAILURE
 from uvicorn.supervisors import ChangeReload, Multiprocess
 from uvicorn.supervisors import ChangeReload, Multiprocess
 
 
-from . import globals, helpers
+from . import globals, helpers, native_mode
 from . import native as native_module
 from . import native as native_module
-from . import native_mode
 from .air import Air
 from .air import Air
 from .language import Language
 from .language import Language
 
 
@@ -53,7 +52,7 @@ def run(*,
         uvicorn_reload_includes: str = '*.py',
         uvicorn_reload_includes: str = '*.py',
         uvicorn_reload_excludes: str = '.*, .py[cod], .sw.*, ~*',
         uvicorn_reload_excludes: str = '.*, .py[cod], .sw.*, ~*',
         tailwind: bool = True,
         tailwind: bool = True,
-        endpoint_documentation: str = '',
+        endpoint_documentation: Literal['none', 'internal', 'page', 'all'] = 'none',
         storage_secret: Optional[str] = None,
         storage_secret: Optional[str] = None,
         **kwargs: Any,
         **kwargs: Any,
         ) -> None:
         ) -> None:
@@ -80,7 +79,7 @@ def run(*,
     :param uvicorn_reload_includes: string with comma-separated list of glob-patterns which trigger reload on modification (default: `'.py'`)
     :param uvicorn_reload_includes: string with comma-separated list of glob-patterns which trigger reload on modification (default: `'.py'`)
     :param uvicorn_reload_excludes: string with comma-separated list of glob-patterns which should be ignored for reload (default: `'.*, .py[cod], .sw.*, ~*'`)
     :param uvicorn_reload_excludes: string with comma-separated list of glob-patterns which should be ignored for reload (default: `'.*, .py[cod], .sw.*, ~*'`)
     :param tailwind: whether to use Tailwind (experimental, default: `True`)
     :param tailwind: whether to use Tailwind (experimental, default: `True`)
-    :param endpoint_documentation: control what endpoints appear in the autogenerated OpenAPI docs (default: '', options: 'all internal page')
+    :param endpoint_documentation: control what endpoints appear in the autogenerated OpenAPI docs (default: 'none', options: 'none', 'internal', 'page', 'all')
     :param storage_secret: secret key for browser based storage (default: `None`, a value is required to enable ui.storage.individual and ui.storage.browser)
     :param storage_secret: secret key for browser based storage (default: `None`, a value is required to enable ui.storage.individual and ui.storage.browser)
     :param kwargs: additional keyword arguments are passed to `uvicorn.run`    
     :param kwargs: additional keyword arguments are passed to `uvicorn.run`    
     '''
     '''
@@ -93,20 +92,20 @@ def run(*,
     globals.language = language
     globals.language = language
     globals.binding_refresh_interval = binding_refresh_interval
     globals.binding_refresh_interval = binding_refresh_interval
     globals.tailwind = tailwind
     globals.tailwind = tailwind
-
-    if 'all' in endpoint_documentation:
-        endpoint_documentation = 'internal page'  # any additional documentation groups need to be added here
+    globals.endpoint_documentation = endpoint_documentation
 
 
     # routes are already created by this point, so we have to iterate through and fix them
     # routes are already created by this point, so we have to iterate through and fix them
     for route in globals.app.routes:
     for route in globals.app.routes:
         if route.path.startswith('/_nicegui'):
         if route.path.startswith('/_nicegui'):
             if hasattr(route, 'methods'):
             if hasattr(route, 'methods'):
-                if 'internal' not in endpoint_documentation:
-                    route.include_in_schema = False
+                route.include_in_schema = False
+                if endpoint_documentation in ['internal', 'all']:
+                    route.include_in_schema = True
 
 
         if route.name == 'decorated':
         if route.name == 'decorated':
-            if 'page' not in endpoint_documentation:
-                route.include_in_schema = False
+            route.include_in_schema = False
+            if endpoint_documentation in ['page', 'all']:
+                route.include_in_schema = True
 
 
     if on_air:
     if on_air:
         globals.air = Air('' if on_air is True else on_air)
         globals.air = Air('' if on_air is True else on_air)

+ 8 - 8
tests/test_endpoint_docs.py

@@ -1,16 +1,16 @@
 import requests
 import requests
 
 
-from nicegui import __version__
+import nicegui
+from nicegui import __version__, ui
 
 
 from .screen import PORT, Screen
 from .screen import PORT, Screen
 
 
 
 
 def test_endpoint_documentation_default(screen: Screen):
 def test_endpoint_documentation_default(screen: Screen):
-    screen.ui_run_kwargs['endpoint_documentation'] = ''
     screen.open('/')
     screen.open('/')
 
 
     response = requests.get(f'http://localhost:{PORT}/openapi.json')
     response = requests.get(f'http://localhost:{PORT}/openapi.json')
-    assert list(response.json()['paths']) == []
+    assert set(response.json()['paths']) == set()
 
 
 
 
 def test_endpoint_documentation_page_only(screen: Screen):
 def test_endpoint_documentation_page_only(screen: Screen):
@@ -18,7 +18,7 @@ def test_endpoint_documentation_page_only(screen: Screen):
     screen.open('/')
     screen.open('/')
 
 
     response = requests.get(f'http://localhost:{PORT}/openapi.json')
     response = requests.get(f'http://localhost:{PORT}/openapi.json')
-    assert list(response.json()['paths']) == ['/']
+    assert set(response.json()['paths']) == {'/'}
 
 
 
 
 def test_endpoint_documentation_internal_only(screen: Screen):
 def test_endpoint_documentation_internal_only(screen: Screen):
@@ -26,10 +26,10 @@ def test_endpoint_documentation_internal_only(screen: Screen):
     screen.open('/')
     screen.open('/')
 
 
     response = requests.get(f'http://localhost:{PORT}/openapi.json')
     response = requests.get(f'http://localhost:{PORT}/openapi.json')
-    assert list(response.json()['paths']) == [
+    assert set(response.json()['paths']) == {
         f'/_nicegui/{__version__}/libraries/{{key}}',
         f'/_nicegui/{__version__}/libraries/{{key}}',
         f'/_nicegui/{__version__}/components/{{key}}',
         f'/_nicegui/{__version__}/components/{{key}}',
-    ]
+    }
 
 
 
 
 def test_endpoint_documentation_all(screen: Screen):
 def test_endpoint_documentation_all(screen: Screen):
@@ -37,8 +37,8 @@ def test_endpoint_documentation_all(screen: Screen):
     screen.open('/')
     screen.open('/')
 
 
     response = requests.get(f'http://localhost:{PORT}/openapi.json')
     response = requests.get(f'http://localhost:{PORT}/openapi.json')
-    assert list(response.json()['paths']) == [
+    assert set(response.json()['paths']) == {
         '/',
         '/',
         f'/_nicegui/{__version__}/libraries/{{key}}',
         f'/_nicegui/{__version__}/libraries/{{key}}',
         f'/_nicegui/{__version__}/components/{{key}}',
         f'/_nicegui/{__version__}/components/{{key}}',
-    ]
+    }