
Merge pull request #620 from Avaiga/feature/#411-raise-error-if-config-id-is-entity-attribute

Feature/#411 - Raise error if config_id is an entity attribute
Đỗ Trường Giang 1 year ago
commit 4d40830959

+ 33 - 1
taipy/core/config/checkers/_data_node_config_checker.py

@@ -10,13 +10,15 @@
 # specific language governing permissions and limitations under the License.
 
 from datetime import timedelta
-from typing import Dict
+from typing import Dict, List
 
 from taipy.config._config import _Config
 from taipy.config.checker._checker import _ConfigChecker
 from taipy.config.checker.issue_collector import IssueCollector
 from taipy.config.common.scope import Scope
 
+from ...scenario.scenario import Scenario
+from ...task.task import Task
 from ..data_node_config import DataNodeConfig
 
 
@@ -26,9 +28,17 @@ class _DataNodeConfigChecker(_ConfigChecker):
 
     def _check(self) -> IssueCollector:
         data_node_configs: Dict[str, DataNodeConfig] = self._config._sections[DataNodeConfig.name]
+        task_attributes = [attr for attr in dir(Task) if not callable(getattr(Task, attr)) and not attr.startswith("_")]
+        scenario_attributes = [
+            attr for attr in dir(Scenario) if not callable(getattr(Scenario, attr)) and not attr.startswith("_")
+        ]
+
         for data_node_config_id, data_node_config in data_node_configs.items():
             self._check_existing_config_id(data_node_config)
             self._check_if_entity_property_key_used_is_predefined(data_node_config)
+            self._check_if_config_id_is_overlapping_with_task_and_scenario_attributes(
+                data_node_config_id, data_node_config, task_attributes, scenario_attributes
+            )
             self._check_storage_type(data_node_config_id, data_node_config)
             self._check_scope(data_node_config_id, data_node_config)
             self._check_validity_period(data_node_config_id, data_node_config)
@@ -38,6 +48,28 @@ class _DataNodeConfigChecker(_ConfigChecker):
             self._check_exposed_type(data_node_config_id, data_node_config)
         return self._collector
 
+    def _check_if_config_id_is_overlapping_with_task_and_scenario_attributes(
+        self,
+        data_node_config_id: str,
+        data_node_config: DataNodeConfig,
+        task_attributes: List[str],
+        scenario_attributes: List[str],
+    ):
+        if data_node_config.id in task_attributes:
+            self._error(
+                data_node_config._ID_KEY,
+                data_node_config.id,
+                f"The id of the DataNodeConfig `{data_node_config_id}` is overlapping with the "
+                f"attribute `{data_node_config.id}` of a Task entity.",
+            )
+        elif data_node_config.id in scenario_attributes:
+            self._error(
+                data_node_config._ID_KEY,
+                data_node_config.id,
+                f"The id of the DataNodeConfig `{data_node_config_id}` is overlapping with the "
+                f"attribute `{data_node_config.id}` of a Scenario entity.",
+            )
+
     def _check_storage_type(self, data_node_config_id: str, data_node_config: DataNodeConfig):
         if data_node_config.storage_type not in DataNodeConfig._ALL_STORAGE_TYPES:
             self._error(
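
Note on the hunk above: the checker derives the reserved names at class level with dir(), keeping only public, non-callable attributes. A property object on a class is data rather than a callable, so properties survive the filter while methods are dropped; plain instance attributes assigned in __init__ would not appear at all, which is presumably why config_id and owner_id are promoted to properties elsewhere in this PR. A minimal, runnable sketch of the idiom, using a hypothetical stand-in class instead of the real Task or Scenario:

    # Hypothetical stand-in class; the real checker runs this over Task and Scenario.
    class _Example:
        creation_date = None          # plain class attribute: kept

        @property
        def owner_id(self):           # property: not callable on the class, kept
            return self._owner_id

        def submit(self):             # method: callable, filtered out
            ...

    public_attrs = [
        attr
        for attr in dir(_Example)
        if not callable(getattr(_Example, attr)) and not attr.startswith("_")
    ]
    print(public_attrs)  # ['creation_date', 'owner_id']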

+ 21 - 0
taipy/core/config/checkers/_task_config_checker.py

@@ -9,10 +9,13 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+from typing import List
+
 from taipy.config._config import _Config
 from taipy.config.checker._checkers._config_checker import _ConfigChecker
 from taipy.config.checker.issue_collector import IssueCollector
 
+from ...scenario.scenario import Scenario
 from ..data_node_config import DataNodeConfig
 from ..task_config import TaskConfig
 
@@ -23,15 +26,33 @@ class _TaskConfigChecker(_ConfigChecker):
 
     def _check(self) -> IssueCollector:
         task_configs = self._config._sections[TaskConfig.name]
+        scenario_attributes = [
+            attr for attr in dir(Scenario) if not callable(getattr(Scenario, attr)) and not attr.startswith("_")
+        ]
+
         for task_config_id, task_config in task_configs.items():
             if task_config_id != _Config.DEFAULT_KEY:
                 self._check_existing_config_id(task_config)
                 self._check_if_entity_property_key_used_is_predefined(task_config)
+                self._check_if_config_id_is_overlapping_with_scenario_attributes(
+                    task_config_id, task_config, scenario_attributes
+                )
                 self._check_existing_function(task_config_id, task_config)
                 self._check_inputs(task_config_id, task_config)
                 self._check_outputs(task_config_id, task_config)
         return self._collector
 
+    def _check_if_config_id_is_overlapping_with_scenario_attributes(
+        self, task_config_id: str, task_config: TaskConfig, scenario_attributes: List[str]
+    ):
+        if task_config.id in scenario_attributes:
+            self._error(
+                task_config._ID_KEY,
+                task_config.id,
+                f"The id of the TaskConfig `{task_config_id}` is overlapping with the "
+                f"attribute `{task_config.id}` of a Scenario entity.",
+            )
+
     def _check_inputs(self, task_config_id: str, task_config: TaskConfig):
         self._check_children(
             TaskConfig, task_config_id, task_config._INPUT_KEY, task_config.input_configs, DataNodeConfig
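
Why the overlap matters: as the scenario.py hunk further down shows, Scenario resolves scenario.<config_id> through __getattr__, which Python only invokes after normal attribute lookup fails. If a task or data node config id matches a real Scenario attribute such as name or cycle, the real attribute wins and the entity becomes unreachable under that name. A toy sketch of the failure mode (the class and data below are made up for illustration):

    class _ToyScenario:
        name = "my scenario"               # real entity attribute

        def __init__(self, entities):
            self._entities = entities      # entities keyed by config_id

        def __getattr__(self, attr):
            # Only called when normal attribute lookup fails.
            if attr in self._entities:
                return self._entities[attr]
            raise AttributeError(attr)

    s = _ToyScenario({"name": "task keyed 'name'", "training": "task keyed 'training'"})
    print(s.training)  # "task keyed 'training'" -- resolved through __getattr__
    print(s.name)      # "my scenario" -- the class attribute shadows the task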

+ 11 - 2
taipy/core/data/data_node.py

@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import networkx as nx
+
 from taipy.config.common._validate_id import _validate_id
 from taipy.config.common.scope import Scope
 from taipy.logger._taipy_logger import _TaipyLogger
@@ -103,9 +104,9 @@ class DataNode(_Entity, _Labeled):
         editor_expiration_date: Optional[datetime] = None,
         **kwargs,
     ):
-        self.config_id = _validate_id(config_id)
+        self._config_id = _validate_id(config_id)
         self.id = id or DataNodeId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())]))
-        self.owner_id = owner_id
+        self._owner_id = owner_id
         self._parent_ids = parent_ids or set()
         self._scope = scope
         self._last_edit_date = last_edit_date
@@ -120,6 +121,14 @@ class DataNode(_Entity, _Labeled):
 
         self._properties = _Properties(self, **kwargs)
 
+    @property
+    def config_id(self):
+        return self._config_id
+
+    @property
+    def owner_id(self):
+        return self._owner_id
+
     def get_parents(self):
         """Get all parents of this data node."""
         from ... import core as tp
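
Here config_id and owner_id move from public instance attributes to read-only properties backed by private fields. Callers can still read them but can no longer reassign them, and the names now appear in dir(DataNode) at class level, which the new config checkers rely on. A minimal sketch of the pattern, with a hypothetical class:

    class _Node:
        def __init__(self, config_id: str):
            self._config_id = config_id   # private backing field

        @property
        def config_id(self) -> str:       # public, read-only view
            return self._config_id

    node = _Node("sales_history")
    print(node.config_id)                 # "sales_history"
    # node.config_id = "other"            # would raise AttributeError (no setter)

The test updates further down (dn._config_id = "foo", output._config_id = ..., data_node._owner_id = ...) follow from this: tests that need to mutate the field now write to the private backing attribute directly.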

+ 10 - 3
taipy/core/scenario/scenario.py

@@ -17,6 +17,7 @@ from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Set, Union
 
 import networkx as nx
+
 from taipy.config.common._template_handler import _TemplateHandler as _tpl
 from taipy.config.common._validate_id import _validate_id
 
@@ -29,7 +30,6 @@ from .._version._version_manager_factory import _VersionManagerFactory
 from ..common._listattributes import _ListAttributes
 from ..common._utils import _Subscriber
 from ..cycle.cycle import Cycle
-from ..data._data_manager_factory import _DataManagerFactory
 from ..data.data_node import DataNode
 from ..data.data_node_id import DataNodeId
 from ..exceptions.exceptions import (
@@ -42,7 +42,6 @@ from ..exceptions.exceptions import (
 from ..job.job import Job
 from ..notification import Event, EventEntityType, EventOperation, Notifier, _make_event
 from ..sequence.sequence import Sequence
-from ..task._task_manager_factory import _TaskManagerFactory
 from ..task.task import Task
 from ..task.task_id import TaskId
 from .scenario_id import ScenarioId
@@ -96,7 +95,7 @@ class Scenario(_Entity, Submittable, _Labeled):
         sequences: Optional[Dict[str, Dict]] = None,
     ):
         super().__init__(subscribers or [])
-        self.config_id = _validate_id(config_id)
+        self._config_id = _validate_id(config_id)
         self.id: ScenarioId = scenario_id or self._new_id(self.config_id)
 
         self._tasks: Union[Set[TaskId], Set[Task], Set] = tasks or set()
@@ -156,6 +155,10 @@ class Scenario(_Entity, Submittable, _Labeled):
             return data_nodes[protected_attribute_name]
         raise AttributeError(f"{attribute_name} is not an attribute of scenario {self.id}")
 
+    @property
+    def config_id(self):
+        return self._config_id
+
     @property  # type: ignore
     @_self_reload(_MANAGER_NAME)
     def sequences(self) -> Dict[str, Sequence]:
@@ -305,6 +308,8 @@ class Scenario(_Entity, Submittable, _Labeled):
         return self.__get_tasks()
 
     def __get_tasks(self) -> Dict[str, Task]:
+        from ..task._task_manager_factory import _TaskManagerFactory
+
         _tasks = {}
         task_manager = _TaskManagerFactory._build_manager()
 
@@ -327,6 +332,8 @@ class Scenario(_Entity, Submittable, _Labeled):
         return self.__get_additional_data_nodes()
 
     def __get_additional_data_nodes(self):
+        from ..data._data_manager_factory import _DataManagerFactory
+
         additional_data_nodes = {}
         data_manager = _DataManagerFactory._build_manager()
 

+ 5 - 1
taipy/core/sequence/sequence.py

@@ -66,7 +66,7 @@ class Sequence(_Entity, Submittable, _Labeled):
         super().__init__(subscribers)
         self.id: SequenceId = sequence_id
         self._tasks = tasks
-        self.owner_id = owner_id
+        self._owner_id = owner_id
         self._parent_ids = parent_ids or set()
         self._properties = _Properties(self, **properties)
         self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
@@ -118,6 +118,10 @@ class Sequence(_Entity, Submittable, _Labeled):
     def parent_ids(self):
         return self._parent_ids
 
+    @property
+    def owner_id(self):
+        return self._owner_id
+
     @property
     def version(self):
         return self._version

+ 10 - 4
taipy/core/task/task.py

@@ -21,9 +21,7 @@ from .._entity._labeled import _Labeled
 from .._entity._properties import _Properties
 from .._entity._reload import _Reloader, _self_reload, _self_setter
 from .._version._version_manager_factory import _VersionManagerFactory
-from ..data._data_manager_factory import _DataManagerFactory
 from ..data.data_node import DataNode
-from ..exceptions.exceptions import NonExistingDataNode
 from ..notification.event import Event, EventEntityType, EventOperation, _make_event
 from .task_id import TaskId
 
@@ -71,9 +69,9 @@ class Task(_Entity, _Labeled):
         version: Optional[str] = None,
         skippable: bool = False,
     ):
-        self.config_id = _validate_id(config_id)
+        self._config_id = _validate_id(config_id)
         self.id = id or TaskId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())]))
-        self.owner_id = owner_id
+        self._owner_id = owner_id
         self._parent_ids = parent_ids or set()
         self.__input = {dn.config_id: dn for dn in input or []}
         self.__output = {dn.config_id: dn for dn in output or []}
@@ -109,6 +107,14 @@ class Task(_Entity, _Labeled):
         self._properties = _Reloader()._reload(self._MANAGER_NAME, self)._properties
         return self._properties
 
+    @property
+    def config_id(self):
+        return self._config_id
+
+    @property
+    def owner_id(self):
+        return self._owner_id
+
     def get_parents(self):
         """Get parents of the task."""
         from ... import core as tp

+ 2 - 1
tests/core/_orchestrator/_dispatcher/test_job_dispatcher.py

@@ -16,6 +16,7 @@ from unittest import mock
 from unittest.mock import MagicMock
 
 from pytest import raises
+
 from taipy.config.config import Config
 from taipy.core import DataNodeId, JobId, TaskId
 from taipy.core._orchestrator._dispatcher._development_job_dispatcher import _DevelopmentJobDispatcher
@@ -146,7 +147,7 @@ def test_exception_in_writing_data():
     job_id = JobId("id1")
     output = MagicMock()
     output.id = DataNodeId("output_id")
-    output.config_id = "my_raising_datanode"
+    output._config_id = "my_raising_datanode"
     output._is_in_cache = False
     output.write.side_effect = ValueError()
     task = Task(config_id="name", properties={}, input=[], function=print, output=[output], id=task_id)

+ 42 - 0
tests/core/config/checkers/test_data_node_config_checker.py

@@ -45,6 +45,48 @@ class TestDataNodeConfigChecker:
         Config.check()
         assert len(Config._collector.errors) == 0
 
+    def test_check_config_id_is_different_from_task_and_scenario_attributes(self, caplog):
+        Config._collector = IssueCollector()
+        config = Config._applied_config
+        Config._compile_configs()
+        Config.check()
+        assert len(Config._collector.errors) == 0
+
+        config._sections[DataNodeConfig.name]["new"] = copy(config._sections[DataNodeConfig.name]["default"])
+
+        for conflict_id in [
+            "function",
+            "input",
+            "output",
+            "parent_ids",
+            "scope",
+            "skippable",
+            "additional_data_nodes",
+            "config_id",
+            "creation_date",
+            "cycle",
+            "data_nodes",
+            "is_primary",
+            "name",
+            "owner_id",
+            "properties",
+            "sequences",
+            "subscribers",
+            "tags",
+            "tasks",
+            "version",
+        ]:
+            config._sections[DataNodeConfig.name]["new"].id = conflict_id
+
+            with pytest.raises(SystemExit):
+                Config._collector = IssueCollector()
+                Config.check()
+            assert len(Config._collector.errors) == 1
+            expected_error_message = (
+                f"The id of the DataNodeConfig `new` is overlapping with the attribute `{conflict_id}` of a"
+            )
+            assert expected_error_message in caplog.text
+
     def test_check_if_entity_property_key_used_is_predefined(self, caplog):
         Config._collector = IssueCollector()
         config = Config._applied_config
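
Taken together with the checker change above, this test pins down the user-visible behaviour. A short sketch of how an application would hit the new error, using only the public Config API that appears in these tests (configure_data_node, check); per the test, check() exits rather than returning when errors are collected:

    from taipy.config.config import Config

    # "cycle" is a public attribute of the Scenario entity, so the new checker
    # flags it: "The id of the DataNodeConfig `cycle` is overlapping with the
    # attribute `cycle` of a Scenario entity."
    Config.configure_data_node("cycle")
    Config.check()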

+ 38 - 0
tests/core/config/checkers/test_task_config_checker.py

@@ -12,6 +12,7 @@
 from copy import copy
 
 import pytest
+
 from taipy.config.checker.issue_collector import IssueCollector
 from taipy.config.config import Config
 from taipy.core.config import TaskConfig
@@ -48,6 +49,43 @@ class TestTaskConfigChecker:
         assert len(Config._collector.errors) == 1
         assert len(Config._collector.warnings) == 2
 
+    def test_check_config_id_is_different_from_all_task_properties(self, caplog):
+        Config._collector = IssueCollector()
+        config = Config._applied_config
+        Config._compile_configs()
+        Config.check()
+        assert len(Config._collector.errors) == 0
+
+        config._sections[TaskConfig.name]["new"] = copy(config._sections[TaskConfig.name]["default"])
+
+        for conflict_id in [
+            "additional_data_nodes",
+            "config_id",
+            "creation_date",
+            "cycle",
+            "data_nodes",
+            "is_primary",
+            "name",
+            "owner_id",
+            "properties",
+            "sequences",
+            "subscribers",
+            "tags",
+            "tasks",
+            "version",
+        ]:
+            config._sections[TaskConfig.name]["new"].id = conflict_id
+
+            with pytest.raises(SystemExit):
+                Config._collector = IssueCollector()
+                Config.check()
+            assert len(Config._collector.errors) == 2
+            expected_error_message = (
+                "The id of the TaskConfig `new` is overlapping with the attribute"
+                f" `{conflict_id}` of a Scenario entity."
+            )
+            assert expected_error_message in caplog.text
+
     def test_check_if_entity_property_key_used_is_predefined(self, caplog):
         Config._collector = IssueCollector()
         config = Config._applied_config

+ 2 - 2
tests/core/data/test_data_manager.py

@@ -12,6 +12,7 @@ import os
 import pathlib
 
 import pytest
+
 from taipy.config.common.scope import Scope
 from taipy.config.config import Config
 from taipy.core._version._version_manager import _VersionManager
@@ -22,7 +23,6 @@ from taipy.core.data.data_node_id import DataNodeId
 from taipy.core.data.in_memory import InMemoryDataNode
 from taipy.core.data.pickle import PickleDataNode
 from taipy.core.exceptions.exceptions import InvalidDataNodeType, ModelNotFound
-
 from tests.core.utils.named_temporary_file import NamedTemporaryFile
 
 
@@ -346,7 +346,7 @@ class TestDataManager:
         assert _DataManager._exists(dn.id)
 
         # changing data node attribute
-        dn.config_id = "foo"
+        dn._config_id = "foo"
         assert dn.config_id == "foo"
         _DataManager._set(dn)
         assert len(_DataManager._get_all()) == 1

+ 2 - 1
tests/core/data/test_data_manager_with_sql_repo.py

@@ -13,6 +13,7 @@ import os
 import pathlib
 
 import pytest
+
 from taipy.config.common.scope import Scope
 from taipy.config.config import Config
 from taipy.core._version._version_manager import _VersionManager
@@ -162,7 +163,7 @@ class TestDataManager:
         assert _DataManager._exists(dn.id)
 
         # changing data node attribute
-        dn.config_id = "foo"
+        dn._config_id = "foo"
         assert dn.config_id == "foo"
         _DataManager._set(dn)
         assert len(_DataManager._get_all()) == 1

+ 2 - 2
tests/core/data/test_data_repositories.py

@@ -52,7 +52,7 @@ class TestDataNodeRepository:
 
         for i in range(10):
             data_node.id = DataNodeId(f"data_node-{i}")
-            data_node.owner_id = f"task-{i}"
+            data_node._owner_id = f"task-{i}"
             repository._save(data_node)
         objs = repository._load_all(filters=[{"owner_id": "task-2"}])
 
@@ -119,7 +119,7 @@ class TestDataNodeRepository:
 
         for i in range(10):
             data_node.id = DataNodeId(f"data_node-{i}")
-            data_node.owner_id = f"task-{i}"
+            data_node._owner_id = f"task-{i}"
             repository._save(data_node)
 
         assert len(repository._load_all()) == 10

+ 1 - 1
tests/core/scenario/test_scenario.py

@@ -628,7 +628,7 @@ def test_auto_set_and_reload(cycle, current_datetime, task, data_node):
         assert scenario.properties["temp_key_5"] == 0
 
         new_datetime_2 = new_datetime + timedelta(5)
-        scenario.config_id = "foo"
+        scenario._config_id = "foo"
         scenario.tasks = set()
         scenario.additional_data_nodes = set()
         scenario.remove_sequences([sequence_1_name])

+ 3 - 2
tests/core/task/test_task_repositories.py

@@ -12,6 +12,7 @@
 import os
 
 import pytest
+
 from taipy.core.data._data_sql_repository import _DataSQLRepository
 from taipy.core.exceptions import ModelNotFound
 from taipy.core.task._task_fs_repository import _TaskFSRepository
@@ -63,7 +64,7 @@ class TestTaskFSRepository:
 
         for i in range(10):
             task.id = TaskId(f"task-{i}")
-            task.owner_id = f"owner-{i}"
+            task._owner_id = f"owner-{i}"
             repository._save(task)
         objs = repository._load_all(filters=[{"owner_id": "owner-2"}])
 
@@ -140,7 +141,7 @@ class TestTaskFSRepository:
 
         for i in range(10):
             task.id = TaskId(f"task-{i}")
-            task.owner_id = f"owner-{i}"
+            task._owner_id = f"owner-{i}"
             repository._save(task)
 
         assert len(repository._load_all()) == 10

+ 4 - 3
tests/core/test_complex_application.py

@@ -15,6 +15,7 @@ from time import sleep
 from unittest.mock import patch
 
 import pandas as pd
+
 import taipy.core.taipy as tp
 from taipy.config import Config
 from taipy.core import Core, Status
@@ -72,9 +73,9 @@ def return_a_number_with_sleep():
 def test_skipped_jobs():
     Config.configure_job_executions(mode=JobConfig._DEVELOPMENT_MODE)
     _OrchestratorFactory._build_orchestrator()
-    input_config = Config.configure_data_node("input")
+    input_config = Config.configure_data_node("input_dn")
     intermediate_config = Config.configure_data_node("intermediate")
-    output_config = Config.configure_data_node("output")
+    output_config = Config.configure_data_node("output_dn")
     task_config_1 = Config.configure_task("first", mult_by_2, input_config, intermediate_config, skippable=True)
     task_config_2 = Config.configure_task("second", mult_by_2, intermediate_config, output_config, skippable=True)
     scenario_config = Config.configure_scenario("scenario", [task_config_1, task_config_2])
@@ -84,7 +85,7 @@ def test_skipped_jobs():
         core.run()
 
         scenario = tp.create_scenario(scenario_config)
-        scenario.input.write(2)
+        scenario.input_dn.write(2)
         scenario.submit()
         assert len(tp.get_jobs()) == 2
         for job in tp.get_jobs():
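
The renames in this last file follow directly from the new check: "input" and "output" are public attributes of the Task entity (both appear in the conflict list of the data node checker test above), so keeping them as data node config ids would now fail Config.check(). Renaming the configs to input_dn and output_dn keeps scenario.<config_id> access unambiguous. A short sketch of the error the old id would now trigger, under the same Config API as above:

    from taipy.config.config import Config

    # With the new checker this id collides with Task.input, so Config.check()
    # reports: "The id of the DataNodeConfig `input` is overlapping with the
    # attribute `input` of a Task entity."
    Config.configure_data_node("input")
    Config.check()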