
Merge pull request #1822 from Avaiga/fix/#680-not-able-to-read-s3dn-except-from-text

Fix/#680 - S3ObjectDataNode.read should return the binary form of the data
Đỗ Trường Giang committed 7 months ago
commit 2410d6a30f
2 files changed, with 113 additions and 8 deletions
  1. taipy/core/data/aws_s3.py (+1 −1)
  2. tests/core/data/test_aws_s3_data_node.py (+112 −7)

taipy/core/data/aws_s3.py (+1 −1)

@@ -152,7 +152,7 @@ class S3ObjectDataNode(DataNode):
             Bucket=properties[self.__AWS_STORAGE_BUCKET_NAME],
             Key=properties[self.__AWS_S3_OBJECT_KEY],
         )
-        return aws_s3_object["Body"].read().decode("utf-8")
+        return aws_s3_object["Body"].read()
 
     def _write(self, data: Any):
         properties = self.properties

tests/core/data/test_aws_s3_data_node.py (+112 −7)

@@ -9,9 +9,16 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import os
+import pathlib
+import pickle
+from io import BytesIO
+
 import boto3
+import pandas as pd
 import pytest
 from moto import mock_s3
+from pandas.testing import assert_frame_equal
 
 from taipy.config import Config
 from taipy.config.common.scope import Scope
@@ -52,9 +59,8 @@ class TestS3ObjectDataNode:
         ],
     )
     @pytest.mark.parametrize("properties", __properties)
-    def test_write(self, properties, data):
+    def test_write_text(self, properties, data):
         bucket_name = properties["aws_s3_bucket_name"]
-        # Create an S3 client
         s3_client = boto3.client("s3")
         # Create a bucket
         s3_client.create_bucket(Bucket=bucket_name)
@@ -62,13 +68,46 @@ class TestS3ObjectDataNode:
         object_key = properties["aws_s3_object_key"]
         # Create Taipy S3ObjectDataNode
         aws_s3_object_dn = S3ObjectDataNode("foo_aws_s3", Scope.SCENARIO, properties=properties)
-        # Put an object in the bucket with Taipy
         aws_s3_object_dn._write(data)
         # Read the object with boto3
         response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
 
         assert response["Body"].read().decode("utf-8") == "Hello, write world!"
 
+    @mock_s3
+    @pytest.mark.parametrize(
+        "data_path",
+        [
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.csv"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.xlsx"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.p"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.parquet"),
+        ],
+    )
+    @pytest.mark.parametrize("properties", __properties)
+    def test_write_binary_data(self, properties, data_path):
+        bucket_name = properties["aws_s3_bucket_name"]
+        s3_client = boto3.client("s3")
+        s3_client.create_bucket(Bucket=bucket_name)
+        object_key = properties["aws_s3_object_key"]
+
+        aws_s3_object_dn = S3ObjectDataNode("foo_aws_s3", Scope.SCENARIO, properties=properties)
+        with open(data_path, "rb") as file_binary_data:
+            aws_s3_object_dn._write(file_binary_data)
+
+        # Read the object with boto3
+        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
+        s3_data = response["Body"].read()
+
+        if data_path.endswith(".csv"):
+            assert_frame_equal(pd.read_csv(BytesIO(s3_data)), pd.read_csv(data_path))
+        elif data_path.endswith(".xlsx"):
+            assert_frame_equal(pd.read_excel(BytesIO(s3_data)), pd.read_excel(data_path))
+        elif data_path.endswith(".parquet"):
+            assert_frame_equal(pd.read_parquet(BytesIO(s3_data)), pd.read_parquet(data_path))
+        elif data_path.endswith(".p"):
+            assert pickle.loads(s3_data) == pickle.load(open(data_path, "rb"))
+
     @mock_s3
     @pytest.mark.parametrize(
         "data",
@@ -77,13 +116,11 @@ class TestS3ObjectDataNode:
         ],
     )
     @pytest.mark.parametrize("properties", __properties)
-    def test_read(self, properties, data):
+    def test_read_text(self, properties, data):
         bucket_name = properties["aws_s3_bucket_name"]
-        # Create an S3 client
         client = boto3.client("s3")
         # Create a bucket
         client.create_bucket(Bucket=bucket_name)
-        # Put an object in the bucket with boto3
         object_key = properties["aws_s3_object_key"]
         object_body = "Hello, read world!"
         client.put_object(Body=object_body, Bucket=bucket_name, Key=object_key)
@@ -92,4 +129,72 @@ class TestS3ObjectDataNode:
         # Read the Object from bucket with Taipy
         response = aws_s3_object_dn._read()
 
-        assert response == data
+        assert response.decode("utf-8") == data
+
+    @mock_s3
+    @pytest.mark.parametrize(
+        "data_path",
+        [
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.csv"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.xlsx"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.p"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.parquet"),
+        ],
+    )
+    @pytest.mark.parametrize("properties", __properties)
+    def test_read_binary_data(self, properties, data_path):
+        bucket_name = properties["aws_s3_bucket_name"]
+        client = boto3.client("s3")
+        client.create_bucket(Bucket=bucket_name)
+        object_key = properties["aws_s3_object_key"]
+
+        with open(data_path, "rb") as file_binary_data:
+            client.put_object(Body=file_binary_data, Bucket=bucket_name, Key=object_key)
+
+        # Create Taipy S3ObjectDataNode
+        aws_s3_object_dn = S3ObjectDataNode("foo_aws_s3", Scope.SCENARIO, properties=properties)
+        # Read the Object from bucket with Taipy
+        read_data = aws_s3_object_dn._read()
+
+        if data_path.endswith(".csv"):
+            assert_frame_equal(pd.read_csv(BytesIO(read_data)), pd.read_csv(data_path))
+        elif data_path.endswith(".xlsx"):
+            assert_frame_equal(pd.read_excel(BytesIO(read_data)), pd.read_excel(data_path))
+        elif data_path.endswith(".parquet"):
+            assert_frame_equal(pd.read_parquet(BytesIO(read_data)), pd.read_parquet(data_path))
+        elif data_path.endswith(".p"):
+            assert pickle.loads(read_data) == pickle.load(open(data_path, "rb"))
+
+    @mock_s3
+    @pytest.mark.parametrize(
+        "data_path",
+        [
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.csv"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.xlsx"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.p"),
+            os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample", "example.parquet"),
+        ],
+    )
+    @pytest.mark.parametrize("properties", __properties)
+    def test_read_file_data(self, properties, data_path):
+        bucket_name = properties["aws_s3_bucket_name"]
+        client = boto3.client("s3")
+        client.create_bucket(Bucket=bucket_name)
+        object_key = properties["aws_s3_object_key"]
+
+        # Upload file to S3 bucket
+        client.upload_file(Filename=data_path, Bucket=bucket_name, Key=object_key)
+
+        # Create Taipy S3ObjectDataNode
+        aws_s3_object_dn = S3ObjectDataNode("foo_aws_s3", Scope.SCENARIO, properties=properties)
+        # Reading the file from the bucket with Taipy should return the binary data of the uploaded file
+        read_data = aws_s3_object_dn._read()
+
+        if data_path.endswith(".csv"):
+            assert_frame_equal(pd.read_csv(BytesIO(read_data)), pd.read_csv(data_path))
+        elif data_path.endswith(".xlsx"):
+            assert_frame_equal(pd.read_excel(BytesIO(read_data)), pd.read_excel(data_path))
+        elif data_path.endswith(".parquet"):
+            assert_frame_equal(pd.read_parquet(BytesIO(read_data)), pd.read_parquet(data_path))
+        elif data_path.endswith(".p"):
+            assert pickle.loads(read_data) == pickle.load(open(data_path, "rb"))