ソースを参照

added check instance types in equal functions

Toan Quach 1 年間 前
コミット
9f7a5e1d81

+ 1 - 1
taipy/core/cycle/cycle.py

@@ -157,7 +157,7 @@ class Cycle(_Entity, _Labeled):
         raise AttributeError(f"{attribute_name} is not an attribute of cycle {self.id}")
 
     def __eq__(self, other):
-        return self.id == other.id
+        return isinstance(other, Cycle) and self.id == other.id
 
     def __hash__(self):
         return hash(self.id)

+ 1 - 1
taipy/core/data/data_node.py

@@ -277,7 +277,7 @@ class DataNode(_Entity, _Labeled):
         return {key: value for key, value in self.properties.items() if key not in self._TAIPY_PROPERTIES}
 
     def __eq__(self, other):
-        return self.id == other.id
+        return isinstance(other, DataNode) and self.id == other.id
 
     def __ne__(self, other):
         return not self == other

+ 1 - 1
taipy/core/scenario/scenario.py

@@ -136,7 +136,7 @@ class Scenario(_Entity, Submittable, _Labeled):
         return hash(self.id)
 
     def __eq__(self, other):
-        return self.id == other.id
+        return isinstance(other, Scenario) and self.id == other.id
 
     def __getattr__(self, attribute_name):
         protected_attribute_name = _validate_id(attribute_name)

+ 1 - 1
taipy/core/sequence/sequence.py

@@ -81,7 +81,7 @@ class Sequence(_Entity, Submittable, _Labeled):
         return hash(self.id)
 
     def __eq__(self, other):
-        return self.id == other.id
+        return isinstance(other, Sequence) and self.id == other.id
 
     def __getattr__(self, attribute_name):
         protected_attribute_name = _validate_id(attribute_name)

+ 1 - 1
taipy/core/task/task.py

@@ -82,7 +82,7 @@ class Task(_Entity, _Labeled):
         return hash(self.id)
 
     def __eq__(self, other):
-        return self.id == other.id
+        return isinstance(other, Task) and self.id == other.id
 
     def __getstate__(self):
         return vars(self)

+ 8 - 6
tests/core/conftest.py

@@ -44,7 +44,7 @@ from taipy.core.cycle.cycle import Cycle
 from taipy.core.cycle.cycle_id import CycleId
 from taipy.core.data._data_manager_factory import _DataManagerFactory
 from taipy.core.data._data_model import _DataNodeModel
-from taipy.core.data.in_memory import InMemoryDataNode
+from taipy.core.data.in_memory import DataNodeId, InMemoryDataNode
 from taipy.core.job._job_manager_factory import _JobManagerFactory
 from taipy.core.job.job import Job
 from taipy.core.job.job_id import JobId
@@ -59,7 +59,7 @@ from taipy.core.sequence.sequence_id import SequenceId
 from taipy.core.submission._submission_manager_factory import _SubmissionManagerFactory
 from taipy.core.submission.submission import Submission
 from taipy.core.task._task_manager_factory import _TaskManagerFactory
-from taipy.core.task.task import Task
+from taipy.core.task.task import Task, TaskId
 
 current_time = datetime.now()
 
@@ -188,7 +188,7 @@ def scenario(cycle):
         set(),
         {},
         set(),
-        ScenarioId("sc_id"),
+        ScenarioId("SCENARIO_scenario_id"),
         current_time,
         is_primary=False,
         tags={"foo"},
@@ -199,7 +199,9 @@ def scenario(cycle):
 
 @pytest.fixture(scope="function")
 def data_node():
-    return InMemoryDataNode("data_node_config_id", Scope.SCENARIO, version="random_version_number")
+    return InMemoryDataNode(
+        "data_node_config_id", Scope.SCENARIO, version="random_version_number", id=DataNodeId("DATANODE_data_node_id")
+    )
 
 
 @pytest.fixture(scope="function")
@@ -225,7 +227,7 @@ def data_node_model():
 @pytest.fixture(scope="function")
 def task(data_node):
     dn = InMemoryDataNode("dn_config_id", Scope.SCENARIO, version="random_version_number")
-    return Task("task_config_id", {}, print, [data_node], [dn])
+    return Task("task_config_id", {}, print, [data_node], [dn], TaskId("TASK_task_id"))
 
 
 @pytest.fixture(scope="function")
@@ -255,7 +257,7 @@ def cycle():
         start_date=example_date,
         end_date=example_date,
         name="cc",
-        id=CycleId("cc_id"),
+        id=CycleId("CYCLE_cycle_id"),
     )
 
 

+ 17 - 0
tests/core/cycle/test_cycle.py

@@ -14,7 +14,24 @@ from datetime import timedelta
 from taipy.config.common.frequency import Frequency
 from taipy.core import CycleId
 from taipy.core.cycle._cycle_manager import _CycleManager
+from taipy.core.cycle._cycle_manager_factory import _CycleManagerFactory
 from taipy.core.cycle.cycle import Cycle
+from taipy.core.task.task import Task
+
+
+def test_cycle_equals(cycle):
+    cycle_manager = _CycleManagerFactory()._build_manager()
+
+    cycle_id = cycle.id
+    cycle_manager._set(cycle)
+
+    # To test if instance is same type
+    task = Task("task", {}, print, [], [], cycle_id)
+
+    cycle_2 = cycle_manager._get(cycle_id)
+    assert cycle == cycle_2
+    assert cycle != cycle_id
+    assert cycle != task
 
 
 def test_create_cycle_entity(current_datetime):

+ 16 - 0
tests/core/data/test_data_node.py

@@ -21,11 +21,13 @@ from taipy.config import Config
 from taipy.config.common.scope import Scope
 from taipy.config.exceptions.exceptions import InvalidConfigurationId
 from taipy.core.data._data_manager import _DataManager
+from taipy.core.data._data_manager_factory import _DataManagerFactory
 from taipy.core.data.data_node import DataNode
 from taipy.core.data.data_node_id import DataNodeId
 from taipy.core.data.in_memory import InMemoryDataNode
 from taipy.core.exceptions.exceptions import DataNodeIsBeingEdited, NoData
 from taipy.core.job.job_id import JobId
+from taipy.core.task.task import Task
 
 from .utils import FakeDataNode
 
@@ -46,6 +48,20 @@ def funct_b_d(input: str):
 
 
 class TestDataNode:
+    def test_dn_equals(self, data_node):
+        data_manager = _DataManagerFactory()._build_manager()
+
+        dn_id = data_node.id
+        data_manager._set(data_node)
+
+        # # To test if instance is same type
+        task = Task("task", {}, print, [], [], dn_id)
+
+        dn_2 = data_manager._get(dn_id)
+        assert data_node == dn_2
+        assert data_node != dn_id
+        assert data_node != task
+
     def test_create_with_default_values(self):
         dn = DataNode("foo_bar")
         assert dn.config_id == "foo_bar"

+ 15 - 0
tests/core/scenario/test_scenario.py

@@ -32,6 +32,21 @@ from taipy.core.task._task_manager_factory import _TaskManagerFactory
 from taipy.core.task.task import Task, TaskId
 
 
+def test_scenario_equals(scenario):
+    scenario_manager = _ScenarioManagerFactory()._build_manager()
+
+    scenario_id = scenario.id
+    scenario_manager._set(scenario)
+
+    # To test if instance is same type
+    task = Task("task", {}, print, [], [], scenario_id)
+
+    scenario_2 = scenario_manager._get(scenario_id)
+    assert scenario == scenario_2
+    assert scenario != scenario_id
+    assert scenario != task
+
+
 def test_create_primary_scenario(cycle):
     scenario = Scenario("foo", set(), {"key": "value"}, is_primary=True, cycle=cycle)
     assert scenario.id is not None

+ 20 - 0
tests/core/sequence/test_sequence.py

@@ -13,6 +13,7 @@ from unittest import mock
 
 import pytest
 
+from taipy.config import Config
 from taipy.config.common.scope import Scope
 from taipy.core.common._utils import _Subscriber
 from taipy.core.data._data_manager_factory import _DataManagerFactory
@@ -28,6 +29,25 @@ from taipy.core.task._task_manager import _TaskManager
 from taipy.core.task.task import Task, TaskId
 
 
+def test_sequence_equals():
+    task_config = Config.configure_task("mult_by_3", print, [], None)
+    scenario_config = Config.configure_scenario("scenario", [task_config])
+
+    scenario = _ScenarioManager._create(scenario_config)
+    scenario.add_sequences({"print": list(scenario.tasks.values())})
+    sequence_1 = scenario.sequences["print"]
+    sequence_id = sequence_1.id
+
+    assert sequence_1.name == "print"
+    sequence_2 = _SequenceManager._get(sequence_id)
+    # To test if instance is same type
+    task = Task("task", {}, print, [], [], sequence_id)
+
+    assert sequence_1 == sequence_2
+    assert sequence_1 != sequence_id
+    assert sequence_1 != task
+
+
 def test_create_sequence():
     input = InMemoryDataNode("foo", Scope.SCENARIO)
     output = InMemoryDataNode("bar", Scope.SCENARIO)

+ 16 - 0
tests/core/task/test_task.py

@@ -22,6 +22,7 @@ from taipy.core.data.csv import CSVDataNode
 from taipy.core.data.data_node import DataNode
 from taipy.core.data.in_memory import InMemoryDataNode
 from taipy.core.task._task_manager import _TaskManager
+from taipy.core.task._task_manager_factory import _TaskManagerFactory
 from taipy.core.task.task import Task
 
 
@@ -45,6 +46,21 @@ def input_config():
     return [DataNodeConfig("input_name_1"), DataNodeConfig("input_name_2"), DataNodeConfig("input_name_3")]
 
 
+def test_task_equals(task):
+    task_manager = _TaskManagerFactory()._build_manager()
+
+    task_id = task.id
+    task_manager._set(task)
+
+    # To test if instance is same type
+    dn = CSVDataNode("foo_bar", Scope.SCENARIO, task_id)
+
+    task_2 = task_manager._get(task_id)
+    assert task == task_2
+    assert task != task_id
+    assert task != dn
+
+
 def test_create_task():
     name = "name_1"
     task = Task(name, {}, print, [], [])