Selaa lähdekoodia

refactored code, added check only input dns, added further tests

Toan Quach 1 vuosi sitten
vanhempi
säilyke
34c8ad4420

+ 27 - 18
taipy/core/_entity/_ready_to_run_property.py

@@ -9,7 +9,6 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
-from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Set, Union
 
 from ..notification import EventOperation, Notifier, _make_event
@@ -25,13 +24,11 @@ class _ReadyToRunProperty:
     IS_SUBMITTABLE_PROPERTY_NAME: str = "is_submittable"
 
     # A dictionary of the data nodes not ready_to_read and their corresponding submittable entities.
-    _datanode_id_submittables: Dict["DataNodeId", Set[Union["ScenarioId", "SequenceId", "TaskId"]]] = defaultdict(set)
+    _datanode_id_submittables: Dict["DataNodeId", Set[Union["ScenarioId", "SequenceId", "TaskId"]]] = {}
 
     # A nested dictionary of the submittable entities (Scenario, Sequence, Task) and
     # the data nodes that make it not ready_to_run with the reason(s)
-    _submittable_id_datanodes: Dict[
-        Union["ScenarioId", "SequenceId", "TaskId"], Dict["DataNodeId", Set[str]]
-    ] = defaultdict(lambda: defaultdict(set))
+    _submittable_id_datanodes: Dict[Union["ScenarioId", "SequenceId", "TaskId"], Dict["DataNodeId", Set[str]]] = {}
 
     @classmethod
     def _add(cls, dn: "DataNode", reason: str) -> None:
@@ -42,28 +39,34 @@ class _ReadyToRunProperty:
         parent_entities = dn.get_parents()
 
         for scenario_parent in parent_entities.get(Scenario._MANAGER_NAME, []):
-            _ReadyToRunProperty.__add(scenario_parent, dn, reason)
+            if dn in scenario_parent.get_inputs():
+                _ReadyToRunProperty.__add(scenario_parent, dn, reason)
         for sequence_parent in parent_entities.get(Sequence._MANAGER_NAME, []):
-            _ReadyToRunProperty.__add(sequence_parent, dn, reason)
+            if dn in sequence_parent.get_inputs():
+                _ReadyToRunProperty.__add(sequence_parent, dn, reason)
         for task_parent in parent_entities.get(Task._MANAGER_NAME, []):
-            _ReadyToRunProperty.__add(task_parent, dn, reason)
+            if dn in task_parent.input.values():
+                _ReadyToRunProperty.__add(task_parent, dn, reason)
 
     @classmethod
     def _remove(cls, datanode: "DataNode", reason: str) -> None:
+        from ..taipy import get as tp_get
+
         # check the data node status to determine the reason to be removed
         submittable_ids: Set = cls._datanode_id_submittables.get(datanode.id, set())
 
         to_remove_dn = False
         for submittable_id in submittable_ids:
             # check remove the reason
-            if reason in cls._submittable_id_datanodes[submittable_id].get(datanode.id, set()):
-                cls._submittable_id_datanodes[submittable_id].get(datanode.id, set()).remove(reason)
-            if len(cls._submittable_id_datanodes[submittable_id][datanode.id]) == 0:
+            if reason in cls._submittable_id_datanodes.get(submittable_id, {}).get(datanode.id, set()):
+                cls._submittable_id_datanodes[submittable_id][datanode.id].remove(reason)
+            if len(cls._submittable_id_datanodes.get(submittable_id, {}).get(datanode.id, set())) == 0:
                 to_remove_dn = True
-                cls._submittable_id_datanodes[submittable_id].pop(datanode.id, None)
-                if len(cls._submittable_id_datanodes[submittable_id]) == 0:
-                    from ..taipy import get as tp_get
-
+                cls._submittable_id_datanodes.get(submittable_id, {}).pop(datanode.id, None)
+                if (
+                    submittable_id in cls._submittable_id_datanodes
+                    and len(cls._submittable_id_datanodes[submittable_id]) == 0
+                ):
                     submittable = tp_get(submittable_id)
                     cls.__publish_submittable_property_event(submittable, True)
                     cls._submittable_id_datanodes.pop(submittable_id, None)
@@ -72,12 +75,18 @@ class _ReadyToRunProperty:
             cls._datanode_id_submittables.pop(datanode.id)
 
     @classmethod
-    def __add(
-        cls, submittable: Union["Scenario", "Sequence", "Task"], datanode: "DataNode", reason: str
-    ) -> None:
+    def __add(cls, submittable: Union["Scenario", "Sequence", "Task"], datanode: "DataNode", reason: str) -> None:
+        if datanode.id not in cls._datanode_id_submittables:
+            cls._datanode_id_submittables[datanode.id] = set()
         cls._datanode_id_submittables[datanode.id].add(submittable.id)
+
         if submittable.id not in cls._submittable_id_datanodes:
             cls.__publish_submittable_property_event(submittable, False)
+
+        if submittable.id not in cls._submittable_id_datanodes:
+            cls._submittable_id_datanodes[submittable.id] = {}
+        if datanode.id not in cls._submittable_id_datanodes[submittable.id]:
+            cls._submittable_id_datanodes[submittable.id][datanode.id] = set()
         cls._submittable_id_datanodes[submittable.id][datanode.id].add(reason)
 
     @staticmethod

+ 31 - 7
tests/core/_entity/test_ready_to_run_property.py

@@ -31,11 +31,13 @@ def test_scenario_is_ready_to_run_property():
 
     dn_config_1 = Config.configure_in_memory_data_node("dn_1", 10)
     dn_config_2 = Config.configure_in_memory_data_node("dn_2", 10)
-    task_config = Config.configure_task("task", print, [dn_config_1, dn_config_2])
+    dn_config_3 = Config.configure_in_memory_data_node("dn_3", 10)
+    task_config = Config.configure_task("task", print, [dn_config_1, dn_config_2], [dn_config_3])
     scenario_config = Config.configure_scenario("sc", {task_config}, set(), Frequency.DAILY)
     scenario = scenario_manager._create(scenario_config)
     dn_1 = scenario.dn_1
     dn_2 = scenario.dn_2
+    dn_3 = scenario.dn_3
 
     assert len(scenario_manager._get_all()) == 1
     assert scenario.id not in _ReadyToRunProperty._submittable_id_datanodes
@@ -96,6 +98,12 @@ def test_scenario_is_ready_to_run_property():
     assert scenario_manager._is_submittable(scenario)
     assert scenario_manager._is_submittable(scenario.id)
 
+    dn_3.edit_in_progress = True
+    assert scenario.id not in _ReadyToRunProperty._submittable_id_datanodes
+    assert dn_3.id not in _ReadyToRunProperty._datanode_id_submittables
+    assert scenario_manager._is_submittable(scenario)
+    assert scenario_manager._is_submittable(scenario.id)
+
 
 def test_sequence_is_ready_to_run_property():
     data_manager = _DataManagerFactory._build_manager()
@@ -107,15 +115,18 @@ def test_sequence_is_ready_to_run_property():
     scenario_id = "SCENARIO_scenario_id"
     dn_1 = PickleDataNode("dn_1", Scope.SCENARIO, parent_ids={task_id, scenario_id}, properties={"default_data": 10})
     dn_2 = PickleDataNode("dn_2", Scope.SCENARIO, parent_ids={task_id, scenario_id}, properties={"default_data": 10})
-    task = Task("task", {}, print, [dn_1, dn_2], id=task_id, parent_ids={scenario_id})
+    dn_3 = PickleDataNode("dn_3", Scope.SCENARIO, parent_ids={task_id, scenario_id}, properties={"default_data": 10})
+    task = Task("task", {}, print, [dn_1, dn_2], [dn_3], id=task_id, parent_ids={scenario_id})
     scenario = Scenario("scenario", {task}, {}, set(), scenario_id=scenario_id)
     data_manager._set(dn_1)
     data_manager._set(dn_2)
+    data_manager._set(dn_3)
     task_manager._set(task)
     scenario_manager._set(scenario)
 
     dn_1 = scenario.dn_1
     dn_2 = scenario.dn_2
+    dn_3 = scenario.dn_3
 
     scenario.add_sequences({"sequence": [task]})
     sequence = scenario.sequences["sequence"]
@@ -204,13 +215,19 @@ def test_sequence_is_ready_to_run_property():
     dn_2.edit_in_progress = False
     assert scenario.id not in _ReadyToRunProperty._submittable_id_datanodes
     assert sequence.id not in _ReadyToRunProperty._submittable_id_datanodes
-    assert dn_2.id not in _ReadyToRunProperty._submittable_id_datanodes[scenario.id]
-    assert dn_2.id not in _ReadyToRunProperty._submittable_id_datanodes[sequence.id]
     assert dn_2.id not in _ReadyToRunProperty._datanode_id_submittables
     assert scenario_manager._is_submittable(scenario)
     assert sequence_manager._is_submittable(sequence)
     assert sequence_manager._is_submittable(sequence.id)
 
+    dn_3.edit_in_progress = True
+    assert scenario.id not in _ReadyToRunProperty._submittable_id_datanodes
+    assert sequence.id not in _ReadyToRunProperty._submittable_id_datanodes
+    assert dn_3.id not in _ReadyToRunProperty._datanode_id_submittables
+    assert scenario_manager._is_submittable(scenario)
+    assert sequence_manager._is_submittable(sequence)
+    assert sequence_manager._is_submittable(sequence.id)
+
 
 def test_task_is_ready_to_run_property():
     task_manager = _TaskManagerFactory._build_manager()
@@ -220,13 +237,15 @@ def test_task_is_ready_to_run_property():
 
     dn_config_1 = Config.configure_pickle_data_node("dn_1", default_data=10)
     dn_config_2 = Config.configure_pickle_data_node("dn_2", default_data=15)
-    task_config = Config.configure_task("task", print, [dn_config_1, dn_config_2])
+    dn_config_3 = Config.configure_pickle_data_node("dn_3", default_data=20)
+    task_config = Config.configure_task("task", print, [dn_config_1, dn_config_2], [dn_config_3])
     scenario_config = Config.configure_scenario("scenario", [task_config])
 
     scenario = scenario_manager._create(scenario_config)
     task = scenario.tasks["task"]
     dn_1 = scenario.dn_1
     dn_2 = scenario.dn_2
+    dn_3 = scenario.dn_3
 
     assert len(task_manager._get_all()) == 1
     assert len(scenario_manager._get_all()) == 1
@@ -309,9 +328,14 @@ def test_task_is_ready_to_run_property():
     dn_2.edit_in_progress = False
     assert scenario.id not in _ReadyToRunProperty._submittable_id_datanodes
     assert task.id not in _ReadyToRunProperty._submittable_id_datanodes
-    assert dn_2.id not in _ReadyToRunProperty._submittable_id_datanodes[scenario.id]
-    assert dn_2.id not in _ReadyToRunProperty._submittable_id_datanodes[task.id]
     assert dn_2.id not in _ReadyToRunProperty._datanode_id_submittables
     assert scenario_manager._is_submittable(scenario)
     assert task_manager._is_submittable(task)
     assert task_manager._is_submittable(task.id)
+
+    dn_3.edit_in_progress = True
+    assert scenario.id not in _ReadyToRunProperty._submittable_id_datanodes
+    assert task.id not in _ReadyToRunProperty._submittable_id_datanodes
+    assert scenario_manager._is_submittable(scenario)
+    assert task_manager._is_submittable(task)
+    assert task_manager._is_submittable(task.id)

+ 14 - 20
tests/core/notification/test_events_published.py

@@ -142,16 +142,17 @@ def test_events_published_for_writing_dn():
 
     # Write input manually trigger 4 data node update events
     # for last_edit_date, editor_id, editor_expiration_date and edit_in_progress
+    scenario.the_input.lock_edit()
     scenario.the_input.write("test")
     snapshot = all_evts.capture()
-    assert len(snapshot.collected_events) == 4
+    assert len(snapshot.collected_events) == 13
     assert snapshot.entity_type_collected.get(EventEntityType.CYCLE, 0) == 0
-    assert snapshot.entity_type_collected.get(EventEntityType.DATA_NODE, 0) == 4
-    assert snapshot.entity_type_collected.get(EventEntityType.TASK, 0) == 0
-    assert snapshot.entity_type_collected.get(EventEntityType.SEQUENCE, 0) == 0
-    assert snapshot.entity_type_collected.get(EventEntityType.SCENARIO, 0) == 0
+    assert snapshot.entity_type_collected.get(EventEntityType.DATA_NODE, 0) == 7
+    assert snapshot.entity_type_collected.get(EventEntityType.TASK, 0) == 2
+    assert snapshot.entity_type_collected.get(EventEntityType.SEQUENCE, 0) == 2
+    assert snapshot.entity_type_collected.get(EventEntityType.SCENARIO, 0) == 2
     assert snapshot.operation_collected.get(EventOperation.CREATION, 0) == 0
-    assert snapshot.operation_collected.get(EventOperation.UPDATE, 0) == 4
+    assert snapshot.operation_collected.get(EventOperation.UPDATE, 0) == 13
     all_evts.stop()
 
 
@@ -167,9 +168,6 @@ def test_events_published_for_scenario_submission():
     register_id_0, register_queue_0 = Notifier.register()
     all_evts = RecordingConsumer(register_id_0, register_queue_0)
     all_evts.start()
-    # Before and after writing value to the unwritten data node trigger:
-    # 3 is_submittable update events for the scenario, sequence and task being not submittable
-    # 3 is_submittable update events for the scenario, sequence and task being submittable
     # Submit a scenario triggers:
     # 1 scenario submission event
     # 7 dn update events (for last_edit_date, editor_id(x2), editor_expiration_date(x2) and edit_in_progress(x2))
@@ -182,16 +180,16 @@ def test_events_published_for_scenario_submission():
     scenario.submit()
     snapshot = all_evts.capture()
 
-    assert len(snapshot.collected_events) == 23
+    assert len(snapshot.collected_events) == 17
     assert snapshot.entity_type_collected.get(EventEntityType.CYCLE, 0) == 0
     assert snapshot.entity_type_collected.get(EventEntityType.DATA_NODE, 0) == 7
-    assert snapshot.entity_type_collected.get(EventEntityType.TASK, 0) == 2
-    assert snapshot.entity_type_collected.get(EventEntityType.SEQUENCE, 0) == 2
-    assert snapshot.entity_type_collected.get(EventEntityType.SCENARIO, 0) == 3
+    assert snapshot.entity_type_collected.get(EventEntityType.TASK, 0) == 0
+    assert snapshot.entity_type_collected.get(EventEntityType.SEQUENCE, 0) == 0
+    assert snapshot.entity_type_collected.get(EventEntityType.SCENARIO, 0) == 1
     assert snapshot.entity_type_collected.get(EventEntityType.JOB, 0) == 4
     assert snapshot.entity_type_collected.get(EventEntityType.SUBMISSION, 0) == 5
     assert snapshot.operation_collected.get(EventOperation.CREATION, 0) == 2
-    assert snapshot.operation_collected.get(EventOperation.UPDATE, 0) == 20
+    assert snapshot.operation_collected.get(EventOperation.UPDATE, 0) == 14
     assert snapshot.operation_collected.get(EventOperation.SUBMISSION, 0) == 1
 
     assert snapshot.attr_name_collected["last_edit_date"] == 1
@@ -201,7 +199,6 @@ def test_events_published_for_scenario_submission():
     assert snapshot.attr_name_collected["status"] == 3
     assert snapshot.attr_name_collected["jobs"] == 1
     assert snapshot.attr_name_collected["submission_status"] == 3
-    assert snapshot.attr_name_collected["is_submittable"] == 6
 
     all_evts.stop()
 
@@ -310,13 +307,10 @@ def test_scenario_events():
 
     scenario.submit()
     snapshot = consumer.capture()
-    assert len(snapshot.collected_events) == 2
-    assert snapshot.collected_events[0].operation == EventOperation.UPDATE
+    assert len(snapshot.collected_events) == 1
+    assert snapshot.collected_events[0].operation == EventOperation.SUBMISSION
     assert snapshot.collected_events[0].entity_type == EventEntityType.SCENARIO
     assert snapshot.collected_events[0].metadata.get("config_id") == scenario.config_id
-    assert snapshot.collected_events[1].operation == EventOperation.SUBMISSION
-    assert snapshot.collected_events[1].entity_type == EventEntityType.SCENARIO
-    assert snapshot.collected_events[1].metadata.get("config_id") == scenario.config_id
 
     # Delete scenario
     tp.delete(scenario.id)

+ 71 - 0
tests/core/notification/test_published_ready_to_run_event.py

@@ -0,0 +1,71 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+
+from taipy.config.config import Config
+from taipy.core.notification.event import EventEntityType, EventOperation
+from taipy.core.notification.notifier import Notifier
+from taipy.core.scenario._scenario_manager_factory import _ScenarioManagerFactory
+from tests.core.notification.test_events_published import RecordingConsumer
+
+
+def empty_fct(inp):
+    return inp
+
+
+def test_published_is_ready_to_run_event():
+    scenario_manager = _ScenarioManagerFactory._build_manager()
+    assert len(scenario_manager._get_all()) == 0
+
+    dn_config_1 = Config.configure_pickle_data_node("dn_1")
+    dn_config_2 = Config.configure_pickle_data_node("dn_2")
+    task_config = Config.configure_task("task", empty_fct, [dn_config_1], [dn_config_2])
+    scenario_config = Config.configure_scenario("sc", {task_config}, set())
+    scenario = scenario_manager._create(scenario_config)
+    scenario.add_sequences({"sequence": [scenario.task]})
+    dn_1 = scenario.dn_1
+    dn_2 = scenario.dn_2
+
+    register_id_0, register_queue_0 = Notifier.register()
+    all_evts = RecordingConsumer(register_id_0, register_queue_0)
+    all_evts.start()
+
+    dn_1.lock_edit()
+    dn_1.write(15)
+
+    snapshot = all_evts.capture()
+
+    assert len(snapshot.collected_events) == 13
+    assert snapshot.entity_type_collected.get(EventEntityType.CYCLE, 0) == 0
+    assert snapshot.entity_type_collected.get(EventEntityType.DATA_NODE, 0) == 7
+    assert snapshot.entity_type_collected.get(EventEntityType.TASK, 0) == 2
+    assert snapshot.entity_type_collected.get(EventEntityType.SEQUENCE, 0) == 2
+    assert snapshot.entity_type_collected.get(EventEntityType.SCENARIO, 0) == 2
+    assert snapshot.operation_collected.get(EventOperation.CREATION, 0) == 0
+    assert snapshot.operation_collected.get(EventOperation.UPDATE, 0) == 13
+    assert snapshot.attr_name_collected["is_submittable"] == 6
+
+    dn_2.write(15)
+    snapshot = all_evts.capture()
+
+    assert len(snapshot.collected_events) == 4
+    assert snapshot.entity_type_collected.get(EventEntityType.CYCLE, 0) == 0
+    assert snapshot.entity_type_collected.get(EventEntityType.DATA_NODE, 0) == 4
+    assert snapshot.entity_type_collected.get(EventEntityType.TASK, 0) == 0
+    assert snapshot.entity_type_collected.get(EventEntityType.SEQUENCE, 0) == 0
+    assert snapshot.entity_type_collected.get(EventEntityType.SCENARIO, 0) == 0
+    assert snapshot.operation_collected.get(EventOperation.CREATION, 0) == 0
+    assert snapshot.operation_collected.get(EventOperation.UPDATE, 0) == 4
+    assert snapshot.attr_name_collected["editor_id"] == 1
+    assert snapshot.attr_name_collected["editor_expiration_date"] == 1
+    assert snapshot.attr_name_collected["edit_in_progress"] == 1
+    assert snapshot.attr_name_collected["last_edit_date"] == 1
+    all_evts.stop()

+ 0 - 2
tests/core/sequence/test_sequence_manager.py

@@ -455,8 +455,6 @@ def test_sequence_notification_subscribe(mocker):
             notify_1,
             notify_1,
             notify_1,
-            notify_1,
-            notify_2,
             notify_2,
             notify_2,
             notify_2,