瀏覽代碼

DataFrame Serializer fix (#2281)

Elijah Ahianyo 1 年之前
父節點
當前提交
9629b59617

+ 1 - 37
reflex/components/datadisplay/datatable.py

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Union
 from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.utils import imports, types
-from reflex.utils.serializers import serialize, serializer
+from reflex.utils.serializers import serialize
 from reflex.vars import BaseVar, ComputedVar, Var
 
 
@@ -129,39 +129,3 @@ class DataTable(Gridjs):
 
         # Render the table.
         return super()._render()
-
-
-try:
-    from pandas import DataFrame
-
-    def format_dataframe_values(df: DataFrame) -> List[List[Any]]:
-        """Format dataframe values to a list of lists.
-
-        Args:
-            df: The dataframe to format.
-
-        Returns:
-            The dataframe as a list of lists.
-        """
-        return [
-            [str(d) if isinstance(d, (list, tuple)) else d for d in data]
-            for data in list(df.values.tolist())
-        ]
-
-    @serializer
-    def serialize_dataframe(df: DataFrame) -> dict:
-        """Serialize a pandas dataframe.
-
-        Args:
-            df: The dataframe to serialize.
-
-        Returns:
-            The serialized dataframe.
-        """
-        return {
-            "columns": df.columns.tolist(),
-            "data": format_dataframe_values(df),
-        }
-
-except ImportError:
-    pass

+ 1 - 11
reflex/components/datadisplay/datatable.pyi

@@ -11,7 +11,7 @@ from typing import Any, Dict, List, Union
 from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.utils import imports, types
-from reflex.utils.serializers import serialize, serializer
+from reflex.utils.serializers import serialize
 from reflex.vars import BaseVar, ComputedVar, Var
 
 class Gridjs(Component):
@@ -183,13 +183,3 @@ class DataTable(Gridjs):
             ValueError: If a pandas dataframe is passed in and columns are also provided.
         """
         ...
-
-try:
-    from pandas import DataFrame
-
-    def format_dataframe_values(df: DataFrame) -> List[List[Any]]: ...
-    @serializer
-    def serialize_dataframe(df: DataFrame) -> dict: ...
-
-except ImportError:
-    pass

+ 0 - 22
reflex/components/graphing/plotly.py

@@ -1,10 +1,8 @@
 """Component for displaying a plotly graph."""
 
-import json
 from typing import Any, Dict, List
 
 from reflex.components.component import NoSSRComponent
-from reflex.utils.serializers import serializer
 from reflex.vars import Var
 
 try:
@@ -42,23 +40,3 @@ class Plotly(PlotlyLib):
 
     # If true, the graph will resize when the window is resized.
     use_resize_handler: Var[bool]
-
-
-try:
-    from plotly.graph_objects import Figure
-    from plotly.io import to_json
-
-    @serializer
-    def serialize_figure(figure: Figure) -> list:
-        """Serialize a plotly figure.
-
-        Args:
-            figure: The figure to serialize.
-
-        Returns:
-            The serialized figure.
-        """
-        return json.loads(str(to_json(figure)))["data"]
-
-except ImportError:
-    pass

+ 0 - 12
reflex/components/graphing/plotly.pyi

@@ -7,10 +7,8 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
-import json
 from typing import Any, Dict, List
 from reflex.components.component import NoSSRComponent
-from reflex.utils.serializers import serializer
 from reflex.vars import Var
 
 try:
@@ -185,13 +183,3 @@ class Plotly(PlotlyLib):
             TypeError: If an invalid child is passed.
         """
         ...
-
-try:
-    from plotly.graph_objects import Figure  # type: ignore
-    from plotly.io import to_json
-
-    @serializer
-    def serialize_figure(figure: Figure) -> list: ...  # type: ignore
-
-except ImportError:
-    pass

+ 0 - 9
reflex/components/media/image.pyi

@@ -121,12 +121,3 @@ class Image(ChakraComponent):
             The Image component.
         """
         ...
-
-try:
-    from PIL.Image import Image as Img
-
-    @serializer
-    def serialize_image(image: Img) -> str: ...
-
-except ImportError:
-    pass

+ 1 - 27
reflex/components/next/image.py

@@ -1,10 +1,8 @@
 """Image component from next/image."""
-import base64
-import io
+
 from typing import Any, Dict, Literal, Optional, Union
 
 from reflex.utils import types
-from reflex.utils.serializers import serializer
 from reflex.vars import Var
 
 from .base import NextComponent
@@ -114,27 +112,3 @@ class Image(NextComponent):
             props["src"] = Var.create(value=src, _var_is_string=True)
 
         return super().create(*children, **props)
-
-
-try:
-    from PIL.Image import Image as Img
-
-    @serializer
-    def serialize_image(image: Img) -> str:
-        """Serialize a plotly figure.
-
-        Args:
-            image: The image to serialize.
-
-        Returns:
-            The serialized image.
-        """
-        buff = io.BytesIO()
-        image.save(buff, format=getattr(image, "format", None) or "PNG")
-        image_bytes = buff.getvalue()
-        base64_image = base64.b64encode(image_bytes).decode("utf-8")
-        mime_type = getattr(image, "get_format_mimetype", lambda: "image/png")()
-        return f"data:{mime_type};base64,{base64_image}"
-
-except ImportError:
-    pass

+ 0 - 12
reflex/components/next/image.pyi

@@ -7,11 +7,8 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
-import base64
-import io
 from typing import Any, Dict, Literal, Optional, Union
 from reflex.utils import types
-from reflex.utils.serializers import serializer
 from reflex.vars import Var
 from .base import NextComponent
 
@@ -123,12 +120,3 @@ class Image(NextComponent):
             _type_: _description_
         """
         ...
-
-try:
-    from PIL.Image import Image as Img
-
-    @serializer
-    def serialize_image(image: Img) -> str: ...
-
-except ImportError:
-    pass

+ 83 - 0
reflex/utils/serializers.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import types as builtin_types
 from datetime import date, datetime, time, timedelta
 from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints
@@ -227,3 +228,85 @@ def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
         The serialized datetime.
     """
     return str(dt)
+
+
+try:
+    from pandas import DataFrame
+
+    def format_dataframe_values(df: DataFrame) -> List[List[Any]]:
+        """Format dataframe values to a list of lists.
+
+        Args:
+            df: The dataframe to format.
+
+        Returns:
+            The dataframe as a list of lists.
+        """
+        return [
+            [str(d) if isinstance(d, (list, tuple)) else d for d in data]
+            for data in list(df.values.tolist())
+        ]
+
+    @serializer
+    def serialize_dataframe(df: DataFrame) -> dict:
+        """Serialize a pandas dataframe.
+
+        Args:
+            df: The dataframe to serialize.
+
+        Returns:
+            The serialized dataframe.
+        """
+        return {
+            "columns": df.columns.tolist(),
+            "data": format_dataframe_values(df),
+        }
+
+except ImportError:
+    pass
+
+try:
+    from plotly.graph_objects import Figure
+    from plotly.io import to_json
+
+    @serializer
+    def serialize_figure(figure: Figure) -> list:
+        """Serialize a plotly figure.
+
+        Args:
+            figure: The figure to serialize.
+
+        Returns:
+            The serialized figure.
+        """
+        return json.loads(str(to_json(figure)))["data"]
+
+except ImportError:
+    pass
+
+
+try:
+    import base64
+    import io
+
+    from PIL.Image import Image as Img
+
+    @serializer
+    def serialize_image(image: Img) -> str:
+        """Serialize a plotly figure.
+
+        Args:
+            image: The image to serialize.
+
+        Returns:
+            The serialized image.
+        """
+        buff = io.BytesIO()
+        image.save(buff, format=getattr(image, "format", None) or "PNG")
+        image_bytes = buff.getvalue()
+        base64_image = base64.b64encode(image_bytes).decode("utf-8")
+        mime_type = getattr(image, "get_format_mimetype", lambda: "image/png")()
+        return f"data:{mime_type};base64,{base64_image}"
+
+except ImportError:
+    pass

+ 1 - 2
tests/components/datadisplay/test_datatable.py

@@ -4,10 +4,9 @@ import pytest
 import reflex as rx
 from reflex.components.datadisplay.datatable import (
     DataTable,
-    serialize_dataframe,  # type: ignore
 )
 from reflex.utils import types
-from reflex.utils.serializers import serialize
+from reflex.utils.serializers import serialize, serialize_dataframe
 
 
 @pytest.mark.parametrize(

+ 1 - 2
tests/components/graphing/test_plotly.py

@@ -2,8 +2,7 @@ import numpy as np
 import plotly.graph_objects as go
 import pytest
 
-from reflex.components.graphing.plotly import serialize_figure  # type: ignore
-from reflex.utils.serializers import serialize
+from reflex.utils.serializers import serialize, serialize_figure
 
 
 @pytest.fixture

+ 2 - 2
tests/components/media/test_image.py

@@ -5,8 +5,8 @@ import pytest
 from PIL.Image import Image as Img
 
 import reflex as rx
-from reflex.components.next.image import Image, serialize_image  # type: ignore
-from reflex.utils.serializers import serialize
+from reflex.components.next.image import Image  # type: ignore
+from reflex.utils.serializers import serialize, serialize_image
 
 
 @pytest.fixture