Quellcode durchsuchen

add support for data urls

Rodja Trappe vor 2 Jahren
Ursprung
Commit
a0a8a823ae
2 geänderte Dateien mit 34 neuen und 7 gelöschten Zeilen
  1. 17 4
      nicegui/favicon.py
  2. 17 3
      tests/test_favicon.py

+ 17 - 4
nicegui/favicon.py

@@ -1,9 +1,10 @@
+import base64
+import io
 import urllib.parse
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
 
-from fastapi import Response
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, Response, StreamingResponse
 
 from . import __version__, globals
 
@@ -35,7 +36,13 @@ def get_favicon_url(page: 'page', prefix: str) -> str:
 
 
 def get_favicon_response() -> Response:
-    return Response(char_to_svg(globals.favicon), media_type='image/svg+xml')
+    if is_svg(globals.favicon):
+        return Response(globals.favicon, media_type='image/svg+xml')
+    elif is_data_url(globals.favicon):
+        media_type, bytes = data_url_to_bytes(globals.favicon)
+        return StreamingResponse(io.BytesIO(bytes), media_type=media_type)
+    elif is_char(globals.favicon):
+        return Response(char_to_svg(globals.favicon), media_type='image/svg+xml')
 
 
 def is_remote_url(favicon: str) -> bool:
@@ -76,3 +83,9 @@ def char_to_svg(char: str) -> str:
 def svg_to_data_url(svg: str) -> str:
     svg_urlencoded = urllib.parse.quote(svg)
     return f'data:image/svg+xml,{svg_urlencoded}'
+
+
+def data_url_to_bytes(data_url: str) -> Tuple[str, bytes]:
+    media_type, base64_image = data_url.split(",", 1)
+    media_type = media_type.split(":")[1].split(";")[0]
+    return media_type, base64.b64decode(base64_image)

+ 17 - 3
tests/test_favicon.py

@@ -18,13 +18,17 @@ def assert_favicon_url_starts_with(screen: Screen, content: str):
     assert icon_link['href'].startswith(content)
 
 
-def assert_favicon(content: Union[Path, str], url_path: str = '/favicon.ico'):
+def assert_favicon(content: Union[Path, str, bytes], url_path: str = '/favicon.ico'):
     response = requests.get(f'http://localhost:{PORT}{url_path}')
     assert response.status_code == 200
     if isinstance(content, Path):
         assert content.read_bytes() == response.content
-    else:
+    elif isinstance(content, str):
         assert content == response.text
+    elif isinstance(content, bytes):
+        assert content == response.content
+    else:
+        raise TypeError(f'Unexpected type: {type(content)}')
 
 
 def test_default(screen: Screen):
@@ -40,10 +44,20 @@ def test_emoji(screen: Screen):
     screen.ui_run_kwargs['favicon'] = '👋'
     screen.open('/')
     assert_favicon_url_starts_with(screen, ''
+    screen.ui_run_kwargs['favicon'] = icon
+    screen.open('/')
+    assert_favicon_url_starts_with(screen, 'data:image/png;base64')
+    _, bytes = favicon.data_url_to_bytes(icon)
+    assert_favicon(bytes)
+
+
 def test_custom_file(screen: Screen):
     ui.label('Hello, world')