Ver Fonte

fixed bug in data model not deserialize parent_ids and add tests to check the changes (#1148)

Co-authored-by: Toan Quach <shiro@Shiros-MacBook-Pro.local>
Co-authored-by: Jean-Robin <jeanrobin.medori@avaiga.com>
Toan Quach há 1 ano atrás
pai
commit
546ad8a9d9

+ 1 - 1
taipy/core/data/_data_model.py

@@ -69,7 +69,7 @@ class _DataNodeModel(_BaseModel):
             scope=Scope._from_repr(data["scope"]),
             scope=Scope._from_repr(data["scope"]),
             storage_type=data["storage_type"],
             storage_type=data["storage_type"],
             owner_id=data.get("owner_id"),
             owner_id=data.get("owner_id"),
-            parent_ids=data.get("parent_ids", []),
+            parent_ids=_BaseModel._deserialize_attribute(data.get("parent_ids", [])),
             last_edit_date=data.get("last_edit_date"),
             last_edit_date=data.get("last_edit_date"),
             edits=_BaseModel._deserialize_attribute(data["edits"]),
             edits=_BaseModel._deserialize_attribute(data["edits"]),
             version=data["version"],
             version=data["version"],

+ 5 - 5
taipy/core/task/task.py

@@ -71,8 +71,8 @@ class Task(_Entity, _Labeled):
         self.id = id or TaskId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())]))
         self.id = id or TaskId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())]))
         self._owner_id = owner_id
         self._owner_id = owner_id
         self._parent_ids = parent_ids or set()
         self._parent_ids = parent_ids or set()
-        self.__input = {dn.config_id: dn for dn in input or []}
-        self.__output = {dn.config_id: dn for dn in output or []}
+        self._input = {dn.config_id: dn for dn in input or []}
+        self._output = {dn.config_id: dn for dn in output or []}
         self._function = function
         self._function = function
         self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
         self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
         self._skippable = skippable
         self._skippable = skippable
@@ -126,11 +126,11 @@ class Task(_Entity, _Labeled):
 
 
     @property
     @property
     def input(self) -> Dict[str, DataNode]:
     def input(self) -> Dict[str, DataNode]:
-        return self.__input
+        return self._input
 
 
     @property
     @property
     def output(self) -> Dict[str, DataNode]:
     def output(self) -> Dict[str, DataNode]:
-        return self.__output
+        return self._output
 
 
     @property
     @property
     def data_nodes(self) -> Dict[str, DataNode]:
     def data_nodes(self) -> Dict[str, DataNode]:
@@ -164,7 +164,7 @@ class Task(_Entity, _Labeled):
             The lowest scope present in input and output data nodes or GLOBAL if there are
             The lowest scope present in input and output data nodes or GLOBAL if there are
                 either no input or no output.
                 either no input or no output.
         """
         """
-        data_nodes = list(self.__input.values()) + list(self.__output.values())
+        data_nodes = list(self._input.values()) + list(self._output.values())
         return Scope(min(dn.scope for dn in data_nodes)) if len(data_nodes) != 0 else Scope.GLOBAL
         return Scope(min(dn.scope for dn in data_nodes)) if len(data_nodes) != 0 else Scope.GLOBAL
 
 
     @property
     @property

+ 10 - 3
tests/core/cycle/test_cycle_repositories.py

@@ -21,12 +21,19 @@ from taipy.core.exceptions import ModelNotFound
 
 
 class TestCycleRepositories:
 class TestCycleRepositories:
     @pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
     @pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
-    def test_save_and_load(self, cycle, repo, init_sql_repo):
+    def test_save_and_load(self, cycle: Cycle, repo, init_sql_repo):
         repository = repo()
         repository = repo()
         repository._save(cycle)
         repository._save(cycle)
 
 
-        obj = repository._load(cycle.id)
-        assert isinstance(obj, Cycle)
+        loaded_cycle = repository._load(cycle.id)
+        assert isinstance(loaded_cycle, Cycle)
+        assert cycle._frequency == loaded_cycle._frequency
+        assert cycle._creation_date == loaded_cycle._creation_date
+        assert cycle._start_date == loaded_cycle._start_date
+        assert cycle._end_date == loaded_cycle._end_date
+        assert cycle._name == loaded_cycle._name
+        assert cycle.id == loaded_cycle.id
+        assert cycle._properties == loaded_cycle._properties
 
 
     @pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
     @pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
     def test_exists(self, cycle, repo, init_sql_repo):
     def test_exists(self, cycle, repo, init_sql_repo):

+ 16 - 3
tests/core/data/test_data_repositories.py

@@ -21,12 +21,25 @@ from taipy.core.exceptions import ModelNotFound
 
 
 class TestDataNodeRepository:
 class TestDataNodeRepository:
     @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
     @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
-    def test_save_and_load(self, data_node, repo, init_sql_repo):
+    def test_save_and_load(self, data_node: DataNode, repo, init_sql_repo):
         repository = repo()
         repository = repo()
         repository._save(data_node)
         repository._save(data_node)
 
 
-        obj = repository._load(data_node.id)
-        assert isinstance(obj, DataNode)
+        loaded_data_node = repository._load(data_node.id)
+        assert isinstance(loaded_data_node, DataNode)
+        assert data_node.id == loaded_data_node.id
+        assert data_node._config_id == loaded_data_node._config_id
+        assert data_node._owner_id == loaded_data_node._owner_id
+        assert data_node._parent_ids == loaded_data_node._parent_ids
+        assert data_node._scope == loaded_data_node._scope
+        assert data_node._last_edit_date == loaded_data_node._last_edit_date
+        assert data_node._edit_in_progress == loaded_data_node._edit_in_progress
+        assert data_node._version == loaded_data_node._version
+        assert data_node._validity_period == loaded_data_node._validity_period
+        assert data_node._editor_id == loaded_data_node._editor_id
+        assert data_node._editor_expiration_date == loaded_data_node._editor_expiration_date
+        assert data_node._edits == loaded_data_node._edits
+        assert data_node._properties == loaded_data_node._properties
 
 
     @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
     @pytest.mark.parametrize("repo", [_DataFSRepository, _DataSQLRepository])
     def test_exists(self, data_node, repo, init_sql_repo):
     def test_exists(self, data_node, repo, init_sql_repo):

+ 14 - 3
tests/core/scenario/test_scenario_repositories.py

@@ -21,12 +21,23 @@ from taipy.core.scenario.scenario import Scenario, ScenarioId
 
 
 class TestScenarioFSRepository:
 class TestScenarioFSRepository:
     @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
     @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
-    def test_save_and_load(self, scenario, repo, init_sql_repo):
+    def test_save_and_load(self, scenario: Scenario, repo, init_sql_repo):
         repository = repo()
         repository = repo()
         repository._save(scenario)
         repository._save(scenario)
 
 
-        obj = repository._load(scenario.id)
-        assert isinstance(obj, Scenario)
+        loaded_scenario = repository._load(scenario.id)
+        assert isinstance(loaded_scenario, Scenario)
+        assert scenario._config_id == loaded_scenario._config_id
+        assert scenario.id == loaded_scenario.id
+        assert scenario._tasks == loaded_scenario._tasks
+        assert scenario._additional_data_nodes == loaded_scenario._additional_data_nodes
+        assert scenario._creation_date == loaded_scenario._creation_date
+        assert scenario._cycle == loaded_scenario._cycle
+        assert scenario._primary_scenario == loaded_scenario._primary_scenario
+        assert scenario._tags == loaded_scenario._tags
+        assert scenario._properties == loaded_scenario._properties
+        assert scenario._sequences == loaded_scenario._sequences
+        assert scenario._version == loaded_scenario._version
 
 
     @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
     @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
     def test_exists(self, scenario, repo, init_sql_repo):
     def test_exists(self, scenario, repo, init_sql_repo):

+ 108 - 76
tests/core/task/test_task_repositories.py

@@ -13,6 +13,8 @@ import os
 
 
 import pytest
 import pytest
 
 
+from taipy.config.config import Config
+from taipy.core.data._data_fs_repository import _DataFSRepository
 from taipy.core.data._data_sql_repository import _DataSQLRepository
 from taipy.core.data._data_sql_repository import _DataSQLRepository
 from taipy.core.exceptions import ModelNotFound
 from taipy.core.exceptions import ModelNotFound
 from taipy.core.task._task_fs_repository import _TaskFSRepository
 from taipy.core.task._task_fs_repository import _TaskFSRepository
@@ -21,149 +23,179 @@ from taipy.core.task.task import Task, TaskId
 
 
 
 
 class TestTaskFSRepository:
 class TestTaskFSRepository:
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_save_and_load(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_save_and_load(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
 
-        repository._save(task)
-
-        obj = repository._load(task.id)
-        assert isinstance(obj, Task)
-
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_exists(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+        task_repository._save(task)
+
+        loaded_task = task_repository._load(task.id)
+        assert isinstance(loaded_task, Task)
+        assert task._config_id == loaded_task._config_id
+        assert task.id == loaded_task.id
+        assert task._owner_id == loaded_task._owner_id
+        assert task._parent_ids == loaded_task._parent_ids
+        assert task._input == loaded_task._input
+        assert task._output == loaded_task._output
+        assert task._function == loaded_task._function
+        assert task._version == loaded_task._version
+        assert task._skippable == loaded_task._skippable
+        assert task._properties == loaded_task._properties
+
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_exists(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
 
-        repository._save(task)
+        task_repository._save(task)
 
 
-        assert repository._exists(task.id)
-        assert not repository._exists("not-existed-task")
+        assert task_repository._exists(task.id)
+        assert not task_repository._exists("not-existed-task")
 
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_load_all(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_load_all(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
 
         for i in range(10):
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task.id = TaskId(f"task-{i}")
-            repository._save(task)
-        data_nodes = repository._load_all()
+            task_repository._save(task)
+        data_nodes = task_repository._load_all()
 
 
         assert len(data_nodes) == 10
         assert len(data_nodes) == 10
 
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_load_all_with_filters(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_load_all_with_filters(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
 
         for i in range(10):
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task.id = TaskId(f"task-{i}")
             task._owner_id = f"owner-{i}"
             task._owner_id = f"owner-{i}"
-            repository._save(task)
-        objs = repository._load_all(filters=[{"owner_id": "owner-2"}])
+            task_repository._save(task)
+        objs = task_repository._load_all(filters=[{"owner_id": "owner-2"}])
 
 
         assert len(objs) == 1
         assert len(objs) == 1
 
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_delete(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_delete(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
-        repository._save(task)
+        task_repository._save(task)
 
 
-        repository._delete(task.id)
+        task_repository._delete(task.id)
 
 
         with pytest.raises(ModelNotFound):
         with pytest.raises(ModelNotFound):
-            repository._load(task.id)
-
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_delete_all(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+            task_repository._load(task.id)
+
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_delete_all(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
 
         for i in range(10):
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task.id = TaskId(f"task-{i}")
-            repository._save(task)
+            task_repository._save(task)
 
 
-        assert len(repository._load_all()) == 10
+        assert len(task_repository._load_all()) == 10
 
 
-        repository._delete_all()
+        task_repository._delete_all()
 
 
-        assert len(repository._load_all()) == 0
+        assert len(task_repository._load_all()) == 0
 
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_delete_many(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_delete_many(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
 
         for i in range(10):
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task.id = TaskId(f"task-{i}")
-            repository._save(task)
+            task_repository._save(task)
 
 
-        objs = repository._load_all()
+        objs = task_repository._load_all()
         assert len(objs) == 10
         assert len(objs) == 10
         ids = [x.id for x in objs[:3]]
         ids = [x.id for x in objs[:3]]
-        repository._delete_many(ids)
+        task_repository._delete_many(ids)
 
 
-        assert len(repository._load_all()) == 7
+        assert len(task_repository._load_all()) == 7
 
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_delete_by(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_delete_by(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
 
 
         # Create 5 entities with version 1.0 and 5 entities with version 2.0
         # Create 5 entities with version 1.0 and 5 entities with version 2.0
         for i in range(10):
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task.id = TaskId(f"task-{i}")
             task._version = f"{(i+1) // 5}.0"
             task._version = f"{(i+1) // 5}.0"
-            repository._save(task)
+            task_repository._save(task)
 
 
-        objs = repository._load_all()
+        objs = task_repository._load_all()
         assert len(objs) == 10
         assert len(objs) == 10
-        repository._delete_by("version", "1.0")
+        task_repository._delete_by("version", "1.0")
 
 
-        assert len(repository._load_all()) == 5
+        assert len(task_repository._load_all()) == 5
 
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_search(self, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_search(self, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node], version="random_version_number")
         task = Task("task_config_id", {}, print, [data_node], [data_node], version="random_version_number")
 
 
         for i in range(10):
         for i in range(10):
             task.id = TaskId(f"task-{i}")
             task.id = TaskId(f"task-{i}")
             task._owner_id = f"owner-{i}"
             task._owner_id = f"owner-{i}"
-            repository._save(task)
+            task_repository._save(task)
 
 
-        assert len(repository._load_all()) == 10
+        assert len(task_repository._load_all()) == 10
 
 
-        objs = repository._search("owner_id", "owner-2")
+        objs = task_repository._search("owner_id", "owner-2")
         assert len(objs) == 1
         assert len(objs) == 1
         assert isinstance(objs[0], Task)
         assert isinstance(objs[0], Task)
 
 
-        objs = repository._search("owner_id", "owner-2", filters=[{"version": "random_version_number"}])
+        objs = task_repository._search("owner_id", "owner-2", filters=[{"version": "random_version_number"}])
         assert len(objs) == 1
         assert len(objs) == 1
         assert isinstance(objs[0], Task)
         assert isinstance(objs[0], Task)
 
 
-        assert repository._search("owner_id", "owner-2", filters=[{"version": "non_existed_version"}]) == []
+        assert task_repository._search("owner_id", "owner-2", filters=[{"version": "non_existed_version"}]) == []
 
 
-    @pytest.mark.parametrize("repo", [_TaskFSRepository, _TaskSQLRepository])
-    def test_export(self, tmpdir, data_node, repo, init_sql_repo):
-        repository = repo()
-        _DataSQLRepository()._save(data_node)
+    @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
+    def test_export(self, tmpdir, data_node, repo, tmp_sqlite):
+        if repo[1] == _DataSQLRepository:
+            Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
+        task_repository, data_repository = repo[0](), repo[1]()
+        data_repository._save(data_node)
         task = Task("task_config_id", {}, print, [data_node], [data_node])
         task = Task("task_config_id", {}, print, [data_node], [data_node])
-        repository._save(task)
+        task_repository._save(task)
 
 
-        repository._export(task.id, tmpdir.strpath)
-        dir_path = repository.dir_path if repo == _TaskFSRepository else os.path.join(tmpdir.strpath, "task")
+        task_repository._export(task.id, tmpdir.strpath)
+        dir_path = task_repository.dir_path if repo[0] == _TaskFSRepository else os.path.join(tmpdir.strpath, "task")
 
 
         assert os.path.exists(os.path.join(dir_path, f"{task.id}.json"))
         assert os.path.exists(os.path.join(dir_path, f"{task.id}.json"))