Selaa lähdekoodia

fix insufficient permissions on enterprise due to scenario get_sequence is based on _create in SequenceManager

Toan Quach 11 kuukautta sitten
vanhempi
säilyke
cc73b3a2c1

+ 3 - 6
taipy/core/scenario/_scenario_manager.py

@@ -38,7 +38,6 @@ from ..exceptions.exceptions import (
     ImportScenarioDoesntHaveAVersion,
     InsufficientScenarioToCompare,
     InvalidScenario,
-    InvalidSequence,
     NonExistingComparator,
     NonExistingScenario,
     NonExistingScenarioConfig,
@@ -190,11 +189,9 @@ class _ScenarioManager(_Manager[Scenario], _VersionMixin):
         if not scenario._is_consistent():
             raise InvalidScenario(scenario.id)
 
-        actual_sequences = scenario._get_sequences()
-        for sequence_name in sequences.keys():
-            if not actual_sequences[sequence_name]._is_consistent():
-                raise InvalidSequence(actual_sequences[sequence_name].id)
-            Notifier.publish(_make_event(actual_sequences[sequence_name], EventOperation.CREATION))
+        from ..sequence._sequence_manager_factory import _SequenceManagerFactory
+
+        _SequenceManagerFactory._build_manager()._bulk_create_from_scenario(scenario)
 
         Notifier.publish(_make_event(scenario, EventOperation.CREATION))
         return scenario

+ 4 - 5
taipy/core/scenario/scenario.py

@@ -268,12 +268,11 @@ class Scenario(_Entity, Submittable, _Labeled):
         _scenario_task_ids = {task.id if isinstance(task, Task) else task for task in _scenario._tasks}
         _sequence_task_ids: Set[TaskId] = {task.id if isinstance(task, Task) else task for task in tasks}
         self.__check_sequence_tasks_exist_in_scenario_tasks(name, _sequence_task_ids, self.id, _scenario_task_ids)
+
         from taipy.core.sequence._sequence_manager_factory import _SequenceManagerFactory
 
         seq_manager = _SequenceManagerFactory._build_manager()
         seq = seq_manager._create(name, tasks, subscribers or [], properties or {}, self.id, self.version)
-        if not seq._is_consistent():
-            raise InvalidSequence(name)
 
         _sequences = _Reloader()._reload(self._MANAGER_NAME, self)._sequences
         _sequences.update(
@@ -391,7 +390,7 @@ class Scenario(_Entity, Submittable, _Labeled):
         sequence_manager = _SequenceManagerFactory._build_manager()
 
         for sequence_name, sequence_data in self._sequences.items():
-            p = sequence_manager._create(
+            sequence = sequence_manager._build_sequence(
                 sequence_name,
                 sequence_data.get(self._SEQUENCE_TASKS_KEY, []),
                 sequence_data.get(self._SEQUENCE_SUBSCRIBERS_KEY, []),
@@ -399,9 +398,9 @@ class Scenario(_Entity, Submittable, _Labeled):
                 self.id,
                 self.version,
             )
-            if not isinstance(p, Sequence):
+            if not isinstance(sequence, Sequence):
                 raise NonExistingSequence(sequence_name, self.id)
-            _sequences[sequence_name] = p
+            _sequences[sequence_name] = sequence
         return _sequences
 
     @property  # type: ignore

+ 59 - 13
taipy/core/sequence/_sequence_manager.py

@@ -21,6 +21,7 @@ from ..common._utils import _Subscriber
 from ..common.reason import Reason
 from ..common.warn_if_inputs_not_ready import _warn_if_inputs_not_ready
 from ..exceptions.exceptions import (
+    InvalidSequence,
     InvalidSequenceId,
     ModelNotFound,
     NonExistingSequence,
@@ -137,18 +138,8 @@ class _SequenceManager(_Manager[Sequence], _VersionMixin):
             cls._logger.error(f"Sequence {sequence.id} belongs to a non-existing Scenario {scenario_id}.")
             raise SequenceBelongsToNonExistingScenario(sequence.id, scenario_id)
 
-    @classmethod
-    def _create(
-        cls,
-        sequence_name: str,
-        tasks: Union[List[Task], List[TaskId]],
-        subscribers: Optional[List[_Subscriber]] = None,
-        properties: Optional[Dict] = None,
-        scenario_id: Optional[ScenarioId] = None,
-        version: Optional[str] = None,
-    ) -> Sequence:
-        sequence_id = Sequence._new_id(sequence_name, scenario_id)
-
+    @staticmethod
+    def __get_sequence_tasks(tasks: Union[List[Task], List[TaskId]]) -> List[Task]:
         task_manager = _TaskManagerFactory._build_manager()
         _tasks: List[Task] = []
         for task in tasks:
@@ -158,11 +149,24 @@ class _SequenceManager(_Manager[Sequence], _VersionMixin):
                 _tasks.append(_task)
             else:
                 raise NonExistingTask(task)
+        return _tasks
 
+    @classmethod
+    def _build_sequence(
+        cls,
+        sequence_name: str,
+        tasks: Union[List[Task], List[TaskId]],
+        subscribers: Optional[List[_Subscriber]] = None,
+        properties: Optional[Dict] = None,
+        scenario_id: Optional[ScenarioId] = None,
+        version: Optional[str] = None,
+    ) -> Sequence:
+        sequence_id = Sequence._new_id(sequence_name, scenario_id)
+        _tasks = cls.__get_sequence_tasks(tasks)
         properties = properties if properties else {}
         properties["name"] = sequence_name
         version = version if version else cls._get_latest_version()
-        sequence = Sequence(
+        return Sequence(
             properties=properties,
             tasks=_tasks,
             sequence_id=sequence_id,
@@ -171,10 +175,52 @@ class _SequenceManager(_Manager[Sequence], _VersionMixin):
             subscribers=subscribers,
             version=version,
         )
+
+    @classmethod
+    def _bulk_create_from_scenario(cls, scenario: Scenario) -> Dict[str, Sequence]:
+        _sequences: Dict[str, Sequence] = {}
+
+        for sequence_name, sequence_data in scenario._sequences.items():
+            sequence = cls._create(
+                sequence_name,
+                sequence_data.get(scenario._SEQUENCE_TASKS_KEY, []),
+                sequence_data.get(scenario._SEQUENCE_SUBSCRIBERS_KEY, []),
+                sequence_data.get(scenario._SEQUENCE_PROPERTIES_KEY, {}),
+                scenario.id,
+                scenario.version,
+            )
+            if not isinstance(sequence, Sequence):
+                raise NonExistingSequence(sequence_name, scenario.id)
+            _sequences[sequence_name] = sequence
+
+            Notifier.publish(_make_event(sequence, EventOperation.CREATION))
+
+        return _sequences
+
+    @classmethod
+    def _create(
+        cls,
+        sequence_name: str,
+        tasks: Union[List[Task], List[TaskId]],
+        subscribers: Optional[List[_Subscriber]] = None,
+        properties: Optional[Dict] = None,
+        scenario_id: Optional[ScenarioId] = None,
+        version: Optional[str] = None,
+    ) -> Sequence:
+        task_manager = _TaskManagerFactory._build_manager()
+        _tasks = cls.__get_sequence_tasks(tasks)
+
+        sequence = cls._build_sequence(sequence_name, _tasks, subscribers, properties, scenario_id, version)
+        sequence_id = sequence.id
+
         for task in _tasks:
             if sequence_id not in task._parent_ids:
                 task._parent_ids.update([sequence_id])
                 task_manager._set(task)
+
+        if not sequence._is_consistent():
+            raise InvalidSequence(sequence.id)
+
         return sequence
 
     @classmethod

+ 17 - 14
tests/core/sequence/test_sequence_manager.py

@@ -73,14 +73,15 @@ def test_raise_sequence_does_not_belong_to_scenario():
 def __init():
     input_dn = InMemoryDataNode("foo", Scope.SCENARIO)
     output_dn = InMemoryDataNode("foo", Scope.SCENARIO)
-    task = Task("task", {}, print, [input_dn], [output_dn], TaskId("task_id"))
+    task = Task("task", {}, print, [input_dn], [output_dn], TaskId("Task_task_id"))
+    _TaskManager._set(task)
     scenario = Scenario("scenario", {task}, {}, set())
     _ScenarioManager._set(scenario)
     return scenario, task
 
 
 def test_set_and_get_sequence_no_existing_sequence():
-    scenario, task = __init()
+    scenario, _ = __init()
     sequence_name_1 = "p1"
     sequence_id_1 = SequenceId(f"SEQUENCE_{sequence_name_1}_{scenario.id}")
     sequence_name_2 = "p2"
@@ -135,6 +136,19 @@ def test_set_and_get():
     assert _TaskManager._get(task.id).id == task.id
 
 
+def test_task_parent_id_set_only_when_create():
+    scenario, task = __init()
+    sequence_name_1 = "p1"
+
+    with mock.patch("taipy.core.task._task_manager._TaskManager._set") as mck:
+        scenario.add_sequences({sequence_name_1: [task]})
+        mck.assert_called_once()
+
+    with mock.patch("taipy.core.task._task_manager._TaskManager._set") as mck:
+        scenario.sequences[sequence_name_1]
+        mck.assert_not_called()
+
+
 def test_get_all_on_multiple_versions_environment():
     # Create 5 sequences from Scenario with 2 versions each
     for version in range(1, 3):
@@ -474,18 +488,7 @@ def test_sequence_notification_subscribe(mocker):
     mocker.patch.object(
         _utils,
         "_load_fct",
-        side_effect=[
-            notify_1,
-            notify_1,
-            notify_1,
-            notify_1,
-            notify_2,
-            notify_2,
-            notify_2,
-            notify_2,
-            notify_2,
-            notify_2,
-        ],
+        side_effect=[notify_1, notify_1, notify_2, notify_2, notify_2, notify_2],
     )
 
     # test subscription