1
0
Эх сурвалжийг харах

fix image serializing - REF-1889 (#2550)

* fix image serializing

* If get_format_mimetype does not work, look up format in Image.MIME

Throw a warning if the image format does not have an associated MIME type and
ultimately fallback to image/png and let the browser figure it out.

* test_media: end to end serialization of PIL images

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Krys 1 жил өмнө
parent
commit
278183b526

+ 162 - 0
integration/test_media.py

@@ -0,0 +1,162 @@
+"""Integration tests for media components."""
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import AppHarness
+
+
+def MediaApp():
+    """Reflex app with generated images."""
+    import httpx
+    from PIL import Image
+
+    import reflex as rx
+
+    class State(rx.State):
+        def _blue(self, format=None) -> Image.Image:
+            img = Image.new("RGB", (200, 200), "blue")
+            if format is not None:
+                img.format = format  # type: ignore
+            return img
+
+        @rx.cached_var
+        def img_default(self) -> Image.Image:
+            return self._blue()
+
+        @rx.cached_var
+        def img_bmp(self) -> Image.Image:
+            return self._blue(format="BMP")
+
+        @rx.cached_var
+        def img_jpg(self) -> Image.Image:
+            return self._blue(format="JPEG")
+
+        @rx.cached_var
+        def img_png(self) -> Image.Image:
+            return self._blue(format="PNG")
+
+        @rx.cached_var
+        def img_gif(self) -> Image.Image:
+            return self._blue(format="GIF")
+
+        @rx.cached_var
+        def img_webp(self) -> Image.Image:
+            return self._blue(format="WEBP")
+
+        @rx.cached_var
+        def img_from_url(self) -> Image.Image:
+            img_url = "https://picsum.photos/id/1/200/300"
+            img_resp = httpx.get(img_url, follow_redirects=True)
+            return Image.open(img_resp)  # type: ignore
+
+    app = rx.App()
+
+    @app.add_page
+    def index():
+        return rx.vstack(
+            rx.input(
+                value=State.router.session.client_token,
+                read_only=True,
+                id="token",
+            ),
+            rx.image(src=State.img_default, alt="Default image", id="default"),
+            rx.image(src=State.img_bmp, alt="BMP image", id="bmp"),
+            rx.image(src=State.img_jpg, alt="JPG image", id="jpg"),
+            rx.image(src=State.img_png, alt="PNG image", id="png"),
+            rx.image(src=State.img_gif, alt="GIF image", id="gif"),
+            rx.image(src=State.img_webp, alt="WEBP image", id="webp"),
+            rx.image(src=State.img_from_url, alt="Image from URL", id="from_url"),
+        )
+
+
+@pytest.fixture()
+def media_app(tmp_path) -> Generator[AppHarness, None, None]:
+    """Start MediaApp app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path,
+        app_source=MediaApp,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.mark.asyncio
+async def test_media_app(media_app: AppHarness):
+    """Display images, ensure the data uri mime type is correct and images load.
+
+    Args:
+        media_app: harness for MediaApp app
+    """
+    assert media_app.app_instance is not None, "app is not running"
+    driver = media_app.frontend()
+
+    # wait for the backend connection to send the token
+    token_input = driver.find_element(By.ID, "token")
+    token = media_app.poll_for_value(token_input)
+    assert token
+
+    # check out the images
+    default_img = driver.find_element(By.ID, "default")
+    bmp_img = driver.find_element(By.ID, "bmp")
+    jpg_img = driver.find_element(By.ID, "jpg")
+    png_img = driver.find_element(By.ID, "png")
+    gif_img = driver.find_element(By.ID, "gif")
+    webp_img = driver.find_element(By.ID, "webp")
+    from_url_img = driver.find_element(By.ID, "from_url")
+
+    def check_image_loaded(img, check_width=" == 200", check_height=" == 200"):
+        return driver.execute_script(
+            "console.log(arguments); return arguments[1].complete "
+            '&& typeof arguments[1].naturalWidth != "undefined" '
+            f"&& arguments[1].naturalWidth {check_width} ",
+            '&& typeof arguments[1].naturalHeight != "undefined" '
+            f"&& arguments[1].naturalHeight {check_height} ",
+            img,
+        )
+
+    default_img_src = default_img.get_attribute("src")
+    assert default_img_src is not None
+    assert default_img_src.startswith("data:image/png;base64")
+    assert check_image_loaded(default_img)
+
+    bmp_img_src = bmp_img.get_attribute("src")
+    assert bmp_img_src is not None
+    assert bmp_img_src.startswith("data:image/bmp;base64")
+    assert check_image_loaded(bmp_img)
+
+    jpg_img_src = jpg_img.get_attribute("src")
+    assert jpg_img_src is not None
+    assert jpg_img_src.startswith("data:image/jpeg;base64")
+    assert check_image_loaded(jpg_img)
+
+    png_img_src = png_img.get_attribute("src")
+    assert png_img_src is not None
+    assert png_img_src.startswith("data:image/png;base64")
+    assert check_image_loaded(png_img)
+
+    gif_img_src = gif_img.get_attribute("src")
+    assert gif_img_src is not None
+    assert gif_img_src.startswith("data:image/gif;base64")
+    assert check_image_loaded(gif_img)
+
+    webp_img_src = webp_img.get_attribute("src")
+    assert webp_img_src is not None
+    assert webp_img_src.startswith("data:image/webp;base64")
+    assert check_image_loaded(webp_img)
+
+    from_url_img_src = from_url_img.get_attribute("src")
+    assert from_url_img_src is not None
+    assert from_url_img_src.startswith("data:image/jpeg;base64")
+    assert check_image_loaded(
+        from_url_img,
+        check_width=" == 200",
+        check_height=" == 300",
+    )

+ 18 - 2
reflex/utils/serializers.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import json
 import types as builtin_types
+import warnings
 from datetime import date, datetime, time, timedelta
 from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints
 
@@ -303,6 +304,7 @@ try:
     import base64
     import io
 
+    from PIL.Image import MIME
     from PIL.Image import Image as Img
 
     @serializer
@@ -316,10 +318,24 @@ try:
             The serialized image.
         """
         buff = io.BytesIO()
-        image.save(buff, format=getattr(image, "format", None) or "PNG")
+        image_format = getattr(image, "format", None) or "PNG"
+        image.save(buff, format=image_format)
         image_bytes = buff.getvalue()
         base64_image = base64.b64encode(image_bytes).decode("utf-8")
-        mime_type = getattr(image, "get_format_mimetype", lambda: "image/png")()
+        try:
+            # Newer method to get the mime type, but does not always work.
+            mime_type = image.get_format_mimetype()  # type: ignore
+        except AttributeError:
+            try:
+                # Fallback method
+                mime_type = MIME[image_format]
+            except KeyError:
+                # Unknown mime_type: warn and return image/png and hope the browser can sort it out.
+                warnings.warn(
+                    f"Unknown mime type for {image} {image_format}. Defaulting to image/png"
+                )
+                mime_type = "image/png"
+
         return f"data:{mime_type};base64,{base64_image}"
 
 except ImportError: