소스 검색

feat: add import() api to import exported scenario

trgiangdo 1 년 전
부모
커밋
508e142ea5

+ 7 - 0
taipy/core/_manager/_manager.py

@@ -160,6 +160,13 @@ class _Manager(Generic[EntityType]):
         """
         return cls._repository._export(id, folder_path)
 
+    @classmethod
+    def _import(cls, entity_file: pathlib.Path, version: str, **kwargs):
+        imported_entity = cls._repository._import(entity_file)
+        imported_entity._version = version
+        cls._set(imported_entity)
+        return imported_entity
+
     @classmethod
     def _is_editable(cls, entity: Union[EntityType, str]) -> bool:
         return True

+ 5 - 0
taipy/core/_repository/_filesystem_repository.py

@@ -131,6 +131,11 @@ class _FileSystemRepository(_AbstractRepository[ModelType, Entity]):
 
         shutil.copy2(self.__get_path(entity_id), export_path)
 
+    def _import(self, entity_file_path: pathlib.Path):
+        file_content = self.__read_file(entity_file_path)
+        entity = self.__file_content_to_entity(file_content)
+        return entity
+
     ###########################################
     # ##   Specific or optimized methods   ## #
     ###########################################

+ 15 - 0
taipy/core/_version/_version_manager.py

@@ -9,6 +9,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import pathlib
 import uuid
 from typing import List, Optional, Union
 
@@ -230,3 +231,17 @@ class _VersionManager(_Manager[_Version]):
     @classmethod
     def _delete_entities_of_multiple_types(cls, _entity_ids):
         raise NotImplementedError
+
+    @classmethod
+    def _import(cls, entity_file: pathlib.Path, version: str, **kwargs):
+        imported_version = cls._repository._import(entity_file)
+
+        comparator_result = Config._comparator._find_conflict_config(  # type: ignore[attr-defined]
+            imported_version.config,
+            Config._applied_config,
+            imported_version.id,
+        )
+        if comparator_result.get(_ComparatorResult.CONFLICTED_SECTION_KEY):
+            raise ConflictedConfigurationError()
+
+        return imported_version

+ 18 - 3
taipy/core/data/_data_manager.py

@@ -26,7 +26,6 @@ from ..exceptions.exceptions import InvalidDataNodeType
 from ..notification import Event, EventEntityType, EventOperation, Notifier, _make_event
 from ..scenario.scenario_id import ScenarioId
 from ..sequence.sequence_id import SequenceId
-from ._abstract_file import _FileDataNodeMixin
 from ._data_fs_repository import _DataFSRepository
 from ._file_datanode_mixin import _FileDataNodeMixin
 from .data_node import DataNode
@@ -182,10 +181,26 @@ class _DataManager(_Manager[DataNode], _VersionMixin):
         else:
             folder = folder_path
 
-        data_export_dir = folder / Config.core.storage_folder
+        data_export_dir = folder / Config.core.storage_folder / os.path.dirname(data_node.path)
         if not data_export_dir.exists():
             data_export_dir.mkdir(parents=True)
 
         data_export_path = data_export_dir / os.path.basename(data_node.path)
         if os.path.exists(data_node.path):
-            shutil.copy(data_node.path, data_export_path)
+            shutil.copy2(data_node.path, data_export_path)
+
+    @classmethod
+    def _import(cls, entity_file: pathlib.Path, version: str, **kwargs):
+        imported_data_node = cls._repository._import(entity_file)
+        imported_data_node._version = version
+        cls._set(imported_data_node)
+
+        if not isinstance(imported_data_node, _FileDataNodeMixin):
+            return imported_data_node
+
+        data_folder: pathlib.Path = pathlib.Path(str(kwargs.get("data_folder")))
+        if not data_folder.exists():
+            return imported_data_node
+
+        if (data_folder / imported_data_node.path).exists():
+            shutil.copy2(data_folder / imported_data_node.path, imported_data_node.path)

+ 35 - 1
taipy/core/exceptions/exceptions.py

@@ -261,7 +261,7 @@ class NonExistingScenarioConfig(Exception):
         self.message = f"Scenario config: {scenario_config_id} does not exist."
 
 
-class InvalidSscenario(Exception):
+class InvalidScenario(Exception):
     """Raised if a Scenario is not a Directed Acyclic Graph."""
 
     def __init__(self, scenario_id: str):
@@ -383,6 +383,40 @@ class ExportFolderAlreadyExists(Exception):
         )
 
 
+class EntitiesToBeImportAlredyExist(Exception):
+    """Raised when entities in the scenario to be imported have already exists"""
+
+    def __init__(self, folder_path):
+        self.message = (
+            f"The import folder {folder_path} contains entities that have already existed."
+            " Please use the 'override' parameter to override those."
+        )
+
+
+class DataToBeImportAlredyExist(Exception):
+    """Raised when data files in the scenario to be imported have already exists"""
+
+    def __init__(self, folder_path):
+        self.message = (
+            f"The import folder {folder_path} contains data files that have already existed."
+            " Please use the 'override' parameter to override those."
+        )
+
+
+class ImportFolderDoesntContainAnyScenario(Exception):
+    """Raised when the import folder doesn't contain any scenario"""
+
+    def __init__(self, folder_path):
+        self.message = f"The import folder {folder_path} doesn't contain any scenario."
+
+
+class ImportScenarioDoesntHaveAVersion(Exception):
+    """Raised when the import folder doesn't contain any scenario"""
+
+    def __init__(self, folder_path):
+        self.message = f"The import scenario in the folder {folder_path} doesn't have a version."
+
+
 class SQLQueryCannotBeExecuted(Exception):
     """Raised when an SQL Query cannot be executed."""
 

+ 99 - 1
taipy/core/taipy.py

@@ -41,7 +41,10 @@ from .data.data_node import DataNode
 from .data.data_node_id import DataNodeId
 from .exceptions.exceptions import (
     DataNodeConfigIsNotGlobal,
+    EntitiesToBeImportAlredyExist,
     ExportFolderAlreadyExists,
+    ImportFolderDoesntContainAnyScenario,
+    ImportScenarioDoesntHaveAVersion,
     InvalidExportPath,
     ModelNotFound,
     NonExistingVersion,
@@ -65,7 +68,7 @@ from .task.task_id import TaskId
 __logger = _TaipyLogger._get_logger()
 
 
-def set(entity: Union[DataNode, Task, Sequence, Scenario, Cycle]):
+def set(entity: Union[DataNode, Task, Sequence, Scenario, Cycle, Submission]):
     """Save or update an entity.
 
     This function allows you to save or update an entity in Taipy.
@@ -1000,6 +1003,101 @@ def export_scenario(
     _VersionManagerFactory._build_manager()._export(scenario.version, folder_path)
 
 
+def import_scenario(folder_path: Union[str, pathlib.Path], override: bool = False):
+    """Import a folder contains an exported scenario into the current Taipy application.
+
+    Args:
+        folder_path (Union[str, pathlib.Path]): The folder path to the scenario to import.
+            If the path doesn't exist, an exception is raised.
+        override (bool): If True, override the entities if existed. Default value is False.
+
+    Return:
+        The imported scenario.
+
+    Raises:
+        FileNotFoundError: If the import folder path does not exist.
+        ImportFolderDoesntContainAnyScenario: If the import folder doesn't contain any scenario.
+        EntitiesToBeImportAlredyExist: If there is any entity in the import folder that has already existed.
+    """
+    entity_managers = {
+        "version": _VersionManagerFactory._build_manager,
+        "scenarios": _ScenarioManagerFactory._build_manager,
+        "jobs": _JobManagerFactory._build_manager,
+        "submissions": _SubmissionManagerFactory._build_manager,
+        "cycles": _CycleManagerFactory._build_manager,
+        "sequences": _SequenceManagerFactory._build_manager,
+        "tasks": _TaskManagerFactory._build_manager,
+        "data_nodes": _DataManagerFactory._build_manager,
+    }
+
+    if isinstance(folder_path, str):
+        folder: pathlib.Path = pathlib.Path(folder_path)
+    else:
+        folder = folder_path
+
+    if not folder.exists():
+        raise FileNotFoundError(f"The import folder '{folder_path}' does not exist.")
+
+    if not (folder / "scenarios").exists():
+        raise ImportFolderDoesntContainAnyScenario(folder_path)
+
+    if not (folder / "version").exists():
+        raise ImportScenarioDoesntHaveAVersion(folder_path)
+    entity_managers["version"]()._import(next((folder / "version").iterdir()), "")
+
+    valid_entity_folders = ["version", "scenarios", "jobs", "submissions", "cycles", "sequences", "tasks", "data_nodes"]
+    valid_data_folder = Config.core.storage_folder
+
+    def check_if_any_importing_entity_exists(log):
+        any_entity_exists = False
+
+        for entity_folder in valid_entity_folders:
+            if not (folder / entity_folder).exists():
+                continue
+
+            manager = entity_managers[entity_folder]()
+
+            for entity_file in (folder / entity_folder).iterdir():
+                if not entity_file.is_file():
+                    continue
+                entity_id = entity_file.stem
+                if manager._exists(entity_id):
+                    log(f"{entity_id} already exists and maybe overridden if imported.")
+                    any_entity_exists = True
+
+        return any_entity_exists
+
+    if override:
+        check_if_any_importing_entity_exists(__logger.warning)
+    else:
+        if check_if_any_importing_entity_exists(__logger.error):
+            raise EntitiesToBeImportAlredyExist(folder_path)
+
+    imported_scenario = None
+
+    for entity_folder in folder.iterdir():
+        if not entity_folder.is_dir() or entity_folder.name not in valid_entity_folders + [valid_data_folder]:
+            __logger.warning(f"{entity_folder} is not a valid Taipy folder and will not be imported.")
+            continue
+
+        # Skip the version folder as it is already checked
+        if entity_folder.name == "version":
+            continue
+
+        entity_type = entity_folder.name
+        manager = entity_managers[entity_type]()
+        for entity_file in entity_folder.iterdir():
+            imported_entity = manager._import(
+                entity_file,
+                version=_VersionManagerFactory._build_manager()._get_latest_version(),
+                data_folder=folder / valid_data_folder,
+            )
+            if entity_type == "scenarios":
+                imported_scenario = imported_entity
+
+    return imported_scenario
+
+
 def get_parents(
     entity: Union[TaskId, DataNodeId, SequenceId, Task, DataNode, Sequence], parent_dict=None
 ) -> Dict[str, Set[_Entity]]:

+ 158 - 0
tests/core/test_taipy/test_import.py

@@ -0,0 +1,158 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import os
+import shutil
+
+import pandas as pd
+import pytest
+
+import taipy.core.taipy as tp
+from taipy import Config, Frequency, Scope
+from taipy.core._version._version_manager import _VersionManager
+from taipy.core.cycle._cycle_manager import _CycleManager
+from taipy.core.data._data_manager import _DataManager
+from taipy.core.exceptions.exceptions import (
+    EntitiesToBeImportAlredyExist,
+    ImportFolderDoesntContainAnyScenario,
+    ImportScenarioDoesntHaveAVersion,
+)
+from taipy.core.job._job_manager import _JobManager
+from taipy.core.scenario._scenario_manager import _ScenarioManager
+from taipy.core.submission._submission_manager import _SubmissionManager
+from taipy.core.task._task_manager import _TaskManager
+
+
+@pytest.fixture(scope="function", autouse=True)
+def clean_tmp_folder():
+    shutil.rmtree("./tmp", ignore_errors=True)
+    yield
+    shutil.rmtree("./tmp", ignore_errors=True)
+
+
+def plus_1(x):
+    return x + 1
+
+
+def plus_1_dataframe(x):
+    return pd.DataFrame({"output": [x + 1]})
+
+
+def configure_test_scenario(input_data, frequency=None):
+    input_cfg = Config.configure_data_node(
+        id=f"i_{input_data}", storage_type="pickle", scope=Scope.SCENARIO, default_data=input_data
+    )
+    csv_output_cfg = Config.configure_data_node(id=f"o_{input_data}_csv", storage_type="csv")
+    excel_output_cfg = Config.configure_data_node(id=f"o_{input_data}_excel", storage_type="excel")
+    parquet_output_cfg = Config.configure_data_node(id=f"o_{input_data}_parquet", storage_type="parquet")
+    json_output_cfg = Config.configure_data_node(id=f"o_{input_data}_json", storage_type="json")
+
+    csv_task_cfg = Config.configure_task(f"t_{input_data}_csv", plus_1_dataframe, input_cfg, csv_output_cfg)
+    excel_task_cfg = Config.configure_task(f"t_{input_data}_excel", plus_1_dataframe, input_cfg, excel_output_cfg)
+    parquet_task_cfg = Config.configure_task(f"t_{input_data}_parquet", plus_1_dataframe, input_cfg, parquet_output_cfg)
+    json_task_cfg = Config.configure_task(f"t_{input_data}_json", plus_1, input_cfg, json_output_cfg)
+    scenario_cfg = Config.configure_scenario(
+        id=f"s_{input_data}",
+        task_configs=[csv_task_cfg, excel_task_cfg, parquet_task_cfg, json_task_cfg],
+        frequency=frequency,
+    )
+
+    return scenario_cfg
+
+
+def export_test_scenario(scenario_cfg, folder_path="./tmp/exp_scenario", override=False, include_data=False):
+    scenario = tp.create_scenario(scenario_cfg)
+    tp.submit(scenario)
+
+    # Export the submitted scenario
+    tp.export_scenario(scenario.id, folder_path, override, include_data)
+    return scenario
+
+
+def test_import_scenario_without_data(init_managers):
+    scenario_cfg = configure_test_scenario(1, frequency=Frequency.DAILY)
+    scenario = export_test_scenario(scenario_cfg)
+
+    init_managers()
+
+    assert _ScenarioManager._get_all() == []
+    imported_scenario = tp.import_scenario("./tmp/exp_scenario")
+
+    # The imported scenario should be the same as the exported scenario
+    assert _ScenarioManager._get_all() == [imported_scenario]
+    assert imported_scenario == scenario
+
+    # All entities belonging to the scenario should be imported
+    assert len(_CycleManager._get_all()) == 1
+    assert len(_TaskManager._get_all()) == 4
+    assert len(_DataManager._get_all()) == 5
+    assert len(_JobManager._get_all()) == 4
+    assert len(_SubmissionManager._get_all()) == 1
+    assert len(_VersionManager._get_all()) == 1
+
+
+def test_import_scenario_with_data(init_managers):
+    scenario_cfg = configure_test_scenario(1, frequency=Frequency.DAILY)
+    export_test_scenario(scenario_cfg, include_data=True)
+
+    init_managers()
+
+    assert _ScenarioManager._get_all() == []
+    imported_scenario = tp.import_scenario("./tmp/exp_scenario")
+
+    # All data of all data nodes should be imported
+    assert all(os.path.exists(dn.path) for dn in imported_scenario.data_nodes.values())
+
+
+def test_import_scenario_when_entities_are_already_existed(caplog):
+    scenario_cfg = configure_test_scenario(1, frequency=Frequency.DAILY)
+    export_test_scenario(scenario_cfg)
+
+    caplog.clear()
+
+    # Import the scenario when the old entities still exist
+    with pytest.raises(EntitiesToBeImportAlredyExist):
+        tp.import_scenario("./tmp/exp_scenario")
+    assert all(log.levelname == "ERROR" for log in caplog.records[1:])
+
+    caplog.clear()
+
+    # Import with override flag
+    assert len(_ScenarioManager._get_all()) == 1
+    tp.import_scenario("./tmp/exp_scenario", override=True)
+    assert all(log.levelname == "WARNING" for log in caplog.records[1:])
+
+    # The scenario is overridden
+    assert len(_ScenarioManager._get_all()) == 1
+
+
+def test_import_a_non_exists_folder():
+    scenario_cfg = configure_test_scenario(1, frequency=Frequency.DAILY)
+    export_test_scenario(scenario_cfg)
+
+    with pytest.raises(FileNotFoundError):
+        tp.import_scenario("non_exists_folder")
+
+
+def test_import_an_empty_folder(tmpdir_factory):
+    empty_folder = tmpdir_factory.mktemp("empty_folder").strpath
+
+    with pytest.raises(ImportFolderDoesntContainAnyScenario):
+        tp.import_scenario(empty_folder)
+
+
+def test_import_with_no_version():
+    scenario_cfg = configure_test_scenario(1, frequency=Frequency.DAILY)
+    export_test_scenario(scenario_cfg)
+    shutil.rmtree("./tmp/exp_scenario/version")
+
+    with pytest.raises(ImportScenarioDoesntHaveAVersion):
+        tp.import_scenario("./tmp/exp_scenario")