瀏覽代碼

fix context manager with sequence management

jrobinAV 1 年之前
父節點
當前提交
a71e7b41a5
共有 2 個文件被更改,包括 29 次插入5 次删除
  1. 6 5
      taipy/core/scenario/scenario.py
  2. 23 0
      tests/core/scenario/test_scenario.py

+ 6 - 5
taipy/core/scenario/scenario.py

@@ -198,8 +198,8 @@ class Scenario(_Entity, Submittable, _Labeled):
         """
         if name in self.sequences:
             raise SequenceAlreadyExists(name, self.id)
-        self._set_sequence(name, tasks, properties, subscribers)
-        Notifier.publish(_make_event(self.sequences[name], EventOperation.CREATION))
+        seq = self._set_sequence(name, tasks, properties, subscribers)
+        Notifier.publish(_make_event(seq, EventOperation.CREATION))
 
     def update_sequence(
         self,
@@ -222,8 +222,8 @@ class Scenario(_Entity, Submittable, _Labeled):
         """
         if name not in self.sequences:
             raise NonExistingSequence(name, self.id)
-        self._set_sequence(name, tasks, properties, subscribers)
-        Notifier.publish(_make_event(self.sequences[name], EventOperation.UPDATE))
+        seq = self._set_sequence(name, tasks, properties, subscribers)
+        Notifier.publish(_make_event(seq, EventOperation.UPDATE))
 
     def _set_sequence(
         self,
@@ -231,7 +231,7 @@ class Scenario(_Entity, Submittable, _Labeled):
         tasks: Union[List[Task], List[TaskId]],
         properties: Optional[Dict] = None,
         subscribers: Optional[List[_Subscriber]] = None,
-    ):
+    ) -> Sequence:
         _scenario = _Reloader()._reload(self._MANAGER_NAME, self)
         _scenario_task_ids = set(task.id if isinstance(task, Task) else task for task in _scenario._tasks)
         _sequence_task_ids: Set[TaskId] = set(task.id if isinstance(task, Task) else task for task in tasks)
@@ -253,6 +253,7 @@ class Scenario(_Entity, Submittable, _Labeled):
             }
         )
         self.sequences = _sequences  # type: ignore
+        return seq
 
     def add_sequences(self, sequences: Dict[str, Union[List[Task], List[TaskId]]]):
         """Add multiple sequences to the scenario.

+ 23 - 0
tests/core/scenario/test_scenario.py

@@ -446,6 +446,29 @@ def test_update_sequence(data_node):
     assert scenario.sequences["seq_1"].properties["new_key"] == "new_value"
 
 
+def test_add_rename_and_remove_sequences_within_context(data_node):
+    task_1 = Task("task_1", {}, print, output=[data_node])
+    task_2 = Task("task_2", {}, print, input=[data_node])
+    _TaskManagerFactory._build_manager()._set(task_1)
+    scenario = Scenario(config_id="scenario", tasks={task_1, task_2}, properties={})
+    _ScenarioManagerFactory._build_manager()._set(scenario)
+
+    with scenario as sc:
+        sc.add_sequence("seq_name", [task_1])
+    assert len(scenario.sequences) == 1
+    assert scenario.sequences["seq_name"].tasks == {"task_1": task_1}
+
+    with scenario as sc:
+        sc.update_sequence("seq_name", [task_2])
+    assert len(scenario.sequences) == 1
+    assert scenario.sequences["seq_name"].tasks == {"task_2": task_2}
+
+    with scenario as sc:
+        sc.remove_sequence("seq_name")
+    assert len(scenario.sequences) == 0
+
+
+
 def test_add_property_to_scenario():
     scenario = Scenario("foo", set(), {"key": "value"})
     assert scenario.properties == {"key": "value"}