Răsfoiți Sursa

fixed bug in data model not deserialize parent_ids and add tests to check the changes (#1148)

Co-authored-by: Toan Quach <shiro@Shiros-MacBook-Pro.local>
Co-authored-by: Jean-Robin <jeanrobin.medori@avaiga.com>
Toan Quach 1 an în urmă
părinte
comite
546ad8a9d9

+ 1 - 1
taipy/core/data/_data_model.py

@@ -69,7 +69,7 @@ class _DataNodeModel(_BaseModel):
             scope=Scope._from_repr(data["scope"]),
             storage_type=data["storage_type"],
             owner_id=data.get("owner_id"),
-            parent_ids=data.get("parent_ids", []),
+            parent_ids=_BaseModel._deserialize_attribute(data.get("parent_ids", [])),
             last_edit_date=data.get("last_edit_date"),
             edits=_BaseModel._deserialize_attribute(data["edits"]),
             version=data["version"],

+ 5 - 5
taipy/core/task/task.py

@@ -71,8 +71,8 @@ class Task(_Entity, _Labeled):
         self.id = id or TaskId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())]))
         self._owner_id = owner_id
         self._parent_ids = parent_ids or set()
-        self.__input = {dn.config_id: dn for dn in input or []}
-        self.__output = {dn.config_id: dn for dn in output or []}
+        self._input = {dn.config_id: dn for dn in input or []}
+        self._output = {dn.config_id: dn for dn in output or []}
         self._function = function
         self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
         self._skippable = skippable
@@ -126,11 +126,11 @@ class Task(_Entity, _Labeled):
 
     @property
     def input(self) -> Dict[str, DataNode]:
-        return self.__input
+        return self._input
 
     @property
     def output(self) -> Dict[str, DataNode]:
-        return self.__output
+        return self._output
 
     @property
     def data_nodes(self) -> Dict[str, DataNode]:
@@ -164,7 +164,7 @@ class Task(_Entity, _Labeled):
             The lowest scope present in input and output data nodes or GLOBAL if there are
                 either no input or no output.
         """
-        data_nodes = list(self.__input.values()) + list(self.__output.values())
+        data_nodes = list(self._input.values()) + list(self._output.values())
         return Scope(min(dn.scope for dn in data_nodes)) if len(data_nodes) != 0 else Scope.GLOBAL
 
     @property

+ 10 - 3
tests/core/cycle/test_cycle_repositories.py

@@ -21,12 +21,19 @@ from taipy.core.exceptions import ModelNotFound
 
 class TestCycleRepositories:
     @pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
-    def test_save_and_load(self, cycle, repo, init_sql_repo):
+    def test_save_and_load(self, cycle: Cycle, repo, init_sql_repo):
         repository = repo()
         repository._save(cycle)
 
-        obj = repository._load(cycle.id)
-        assert isinstance(obj, Cycle)
+        loaded_cycle = repository._load(cycle.id)
+        assert isinstance(loaded_cycle, Cycle)
+        assert cycle._frequency == loaded_cycle._frequency
+        assert cycle._creation_date == loaded_cycle._creation_date
+        assert cycle._start_date == loaded_cycle._start_date
+        assert cycle._end_date == loaded_cycle._end_date
+        assert cycle._name == loaded_cycle._name
+        assert cycle.id == loaded_cycle.id
+        assert cycle._properties == loaded_cycle._properties
 
     @pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
     def test_exists(self, cycle, repo, init_sql_repo):

+ 16 - 3
tests/core/data/test_data_repositories.py

@@ -21,12 +21,25 @@ from taipy.core.exceptions import ModelNotFound
 
 class TestDataNodeRepository:
     @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
-    def test_save_and_load(self, data_node, repo, init_sql_repo):
+    def test_save_and_load(self, data_node: DataNode, repo, init_sql_repo):
         repository = repo()
         repository._save(data_node)
 
-        obj = repository._load(data_node.id)
-        assert isinstance(obj, DataNode)
+        loaded_data_node = repository._load(data_node.id)
+        assert isinstance(loaded_data_node, DataNode)
+        assert data_node.id == loaded_data_node.id
+        assert data_node._config_id == loaded_data_node._config_id
+        assert data_node._owner_id == loaded_data_node._owner_id
+        assert data_node._parent_ids == loaded_data_node._parent_ids
+        assert data_node._scope == loaded_data_node._scope
+        assert data_node._last_edit_date == loaded_data_node._last_edit_date
+        assert data_node._edit_in_progress == loaded_data_node._edit_in_progress
+        assert data_node._version == loaded_data_node._version
+        assert data_node._validity_period == loaded_data_node._validity_period
+        assert data_node._editor_id == loaded_data_node._editor_id
+        assert data_node._editor_expiration_date == loaded_data_node._editor_expiration_date
+        assert data_node._edits == loaded_data_node._edits
+        assert data_node._properties == loaded_data_node._properties
 
     @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
     def test_exists(self, data_node, repo, init_sql_repo):

+ 14 - 3
tests/core/scenario/test_scenario_repositories.py

@@ -21,12 +21,23 @@ from taipy.core.scenario.scenario import Scenario, ScenarioId
 
 class TestScenarioFSRepository:
     @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
-    def test_save_and_load(self, scenario, repo, init_sql_repo):
+    def test_save_and_load(self, scenario: Scenario, repo, init_sql_repo):
         repository = repo()
         repository._save(scenario)
 
-        obj = repository._load(scenario.id)
-        assert isinstance(obj, Scenario)
+        loaded_scenario = repository._load(scenario.id)
+        assert isinstance(loaded_scenario, Scenario)
+        assert scenario._config_id == loaded_scenario._config_id
+        assert scenario.id == loaded_scenario.id
+        assert scenario._tasks == loaded_scenario._tasks
+        assert scenario._additional_data_nodes == loaded_scenario._additional_data_nodes
+        assert scenario._creation_date == loaded_scenario._creation_date
+        assert scenario._cycle == loaded_scenario._cycle
+        assert scenario._primary_scenario == loaded_scenario._primary_scenario
+        assert scenario._tags == loaded_scenario._tags
+        assert scenario._properties == loaded_scenario._properties
+        assert scenario._sequences == loaded_scenario._sequences
+        assert scenario._version == loaded_scenario._version
 
     @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
     def test_exists(self, scenario, repo, init_sql_repo):

+ 108 - 76
tests/core/task/test_task_repositories.py

@@ -13,6 +13,8 @@ import os
 
 import pytest
 
+from taipy.config.config import Config
+from taipy.core.data._data_fs_repository import _DataFSRepository
 from taipy.core.data._data_sql_repository import _DataSQLRepository
 from taipy.core.exceptions import ModelNotFound
 from taipy.core.task._task_fs_repository import _TaskFSRepository
@@ -21,149 +23,179 @@ from taipy.core.task.task import Task, TaskId
 
 
 class TestTaskFSRepository:
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_save_and_load(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_save_and_load(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
-        repository._save(task)
-
-        obj = repository._load(task.id)
-        assert isinstance(obj, Task)
-
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_exists(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+        task_repository._save(task)
+
+        loaded_task = task_repository._load(task.id)
+        assert isinstance(loaded_task, Task)
+        assert task._config_id == loaded_task._config_id
+        assert task.id == loaded_task.id
+        assert task._owner_id == loaded_task._owner_id
+        assert task._parent_ids == loaded_task._parent_ids
+        assert task._input == loaded_task._input
+        assert task._output == loaded_task._output
+        assert task._function == loaded_task._function
+        assert task._version == loaded_task._version
+        assert task._skippable == loaded_task._skippable
+        assert task._properties == loaded_task._properties
+
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_exists(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
-        repository._save(task)
+        task_repository._save(task)
 
-        assert repository._exists(task.id)
-        assert not repository._exists("not-existed-task")
+        assert task_repository._exists(task.id)
+        assert not task_repository._exists("not-existed-task")
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_load_all(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_load_all(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
         for i in range(10):
             task.id = TaskId(f"task-{i}")
-            repository._save(task)
-        data_nodes = repository._load_all()
+            task_repository._save(task)
+        data_nodes = task_repository._load_all()
 
         assert len(data_nodes) == 10
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_load_all_with_filters(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_load_all_with_filters(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task._owner_id = f"owner-{i}"
-            repository._save(task)
-        objs = repository._load_all(filters=[{"owner_id": "owner-2"}])
+            task_repository._save(task)
+        objs = task_repository._load_all(filters=[{"owner_id": "owner-2"}])
 
         assert len(objs) == 1
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_delete(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_delete(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
-        repository._save(task)
+        task_repository._save(task)
 
-        repository._delete(task.id)
+        task_repository._delete(task.id)
 
         with pytest.raises(ModelNotFound):
-            repository._load(task.id)
-
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_delete_all(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+            task_repository._load(task.id)
+
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_delete_all(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
         for i in range(10):
             task.id = TaskId(f"task-{i}")
-            repository._save(task)
+            task_repository._save(task)
 
-        assert len(repository._load_all()) == 10
+        assert len(task_repository._load_all()) == 10
 
-        repository._delete_all()
+        task_repository._delete_all()
 
-        assert len(repository._load_all()) == 0
+        assert len(task_repository._load_all()) == 0
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_delete_many(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_delete_many(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
         for i in range(10):
             task.id = TaskId(f"task-{i}")
-            repository._save(task)
+            task_repository._save(task)
 
-        objs = repository._load_all()
+        objs = task_repository._load_all()
         assert len(objs) == 10
         ids = [x.id for x in objs[:3]]
-        repository._delete_many(ids)
+        task_repository._delete_many(ids)
 
-        assert len(repository._load_all()) == 7
+        assert len(task_repository._load_all()) == 7
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_delete_by(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_delete_by(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
         # Create 5 entities with version 1.0 and 5 entities with version 2.0
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task._version = f"{(i+1) // 5}.0"
-            repository._save(task)
+            task_repository._save(task)
 
-        objs = repository._load_all()
+        objs = task_repository._load_all()
         assert len(objs) == 10
-        repository._delete_by("version", "1.0")
+        task_repository._delete_by("version", "1.0")
 
-        assert len(repository._load_all()) == 5
+        assert len(task_repository._load_all()) == 5
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_search(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_search(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node], version="random_version_number")
 
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task._owner_id = f"owner-{i}"
-            repository._save(task)
+            task_repository._save(task)
 
-        assert len(repository._load_all()) == 10
+        assert len(task_repository._load_all()) == 10
 
-        objs = repository._search("owner_id", "owner-2")
+        objs = task_repository._search("owner_id", "owner-2")
         assert len(objs) == 1
         assert isinstance(objs[0], Task)
 
-        objs = repository._search("owner_id", "owner-2", filters=[{"version": "random_version_number"}])
+        objs = task_repository._search("owner_id", "owner-2", filters=[{"version": "random_version_number"}])
         assert len(objs) == 1
         assert isinstance(objs[0], Task)
 
-        assert repository._search("owner_id", "owner-2", filters=[{"version": "non_existed_version"}]) == []
+        assert task_repository._search("owner_id", "owner-2", filters=[{"version": "non_existed_version"}]) == []
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_export(self, tmpdir, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_export(self, tmpdir, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
-        repository._save(task)
+        task_repository._save(task)
 
-        repository._export(task.id, tmpdir.strpath)
-        dir_path = repository.dir_path if repo == _TaskFSRepository else os.path.join(tmpdir.strpath, "task")
+        task_repository._export(task.id, tmpdir.strpath)
+        dir_path = task_repository.dir_path if repo[0] == _TaskFSRepository else os.path.join(tmpdir.strpath, "task")
 
         assert os.path.exists(os.path.join(dir_path, f"{task.id}.json"))