123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558 |
- # Copyright 2021-2024 Avaiga Private Limited
- #
- # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
- # the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
- # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations under the License.
- import datetime
- import pathlib
- import tempfile
- import zipfile
- from functools import partial
- from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
- from taipy.config import Config
- from .._entity._entity_ids import _EntityIds
- from .._manager._manager import _Manager
- from .._repository._abstract_repository import _AbstractRepository
- from .._version._version_manager_factory import _VersionManagerFactory
- from .._version._version_mixin import _VersionMixin
- from ..common.warn_if_inputs_not_ready import _warn_if_inputs_not_ready
- from ..config.scenario_config import ScenarioConfig
- from ..cycle._cycle_manager_factory import _CycleManagerFactory
- from ..cycle.cycle import Cycle
- from ..data._data_manager_factory import _DataManagerFactory
- from ..exceptions.exceptions import (
- DeletingPrimaryScenario,
- DifferentScenarioConfigs,
- DoesNotBelongToACycle,
- EntitiesToBeImportAlredyExist,
- ImportArchiveDoesntContainAnyScenario,
- ImportScenarioDoesntHaveAVersion,
- InsufficientScenarioToCompare,
- InvalidScenario,
- NonExistingComparator,
- NonExistingScenario,
- NonExistingScenarioConfig,
- SequenceTaskConfigDoesNotExistInSameScenarioConfig,
- UnauthorizedTagError,
- )
- from ..job._job_manager_factory import _JobManagerFactory
- from ..job.job import Job
- from ..notification import EventEntityType, EventOperation, Notifier, _make_event
- from ..reason._reason_factory import _build_not_submittable_entity_reason, _build_wrong_config_type_reason
- from ..reason.reason import Reasons
- from ..submission._submission_manager_factory import _SubmissionManagerFactory
- from ..submission.submission import Submission
- from ..task._task_manager_factory import _TaskManagerFactory
- from .scenario import Scenario
- from .scenario_id import ScenarioId
class _ScenarioManager(_Manager[Scenario], _VersionMixin):
    """
    Manager of `Scenario` entities.

    Groups the class-level operations applied to scenarios: creation from a configuration,
    subscription handling, submission, tagging, primary promotion, comparison, deletion and
    import from an archive.
    """

    # Key of the scenario property listing the tags a scenario may receive (read by `_tag`).
    _AUTHORIZED_TAGS_KEY = "authorized_tags"
    # Name of the managed entity class.
    _ENTITY_NAME = Scenario.__name__
    # Event entity type associated with scenarios.
    # NOTE(review): presumably consumed by the `_Manager` base class when publishing events - confirm.
    _EVENT_ENTITY_TYPE = EventEntityType.SCENARIO
    # Repository used to load and search scenarios.
    _repository: _AbstractRepository
@classmethod
def _get_all(cls, version_number: Optional[str] = None) -> List[Scenario]:
    """
    Load every scenario selected by the version filters built from `version_number`.
    """
    return cls._repository._load_all(cls._build_filters_with_version(version_number))
@classmethod
def _subscribe(
    cls,
    callback: Callable[[Scenario, Job], None],
    params: Optional[List[Any]] = None,
    scenario: Optional[Scenario] = None,
) -> None:
    """
    Register `callback` (with `params`) on `scenario`, or on every scenario when none is given.
    """
    targets = cls._get_all() if scenario is None else [scenario]
    for target in targets:
        cls.__add_subscriber(callback, params, target)
@classmethod
def _unsubscribe(
    cls,
    callback: Callable[[Scenario, Job], None],
    params: Optional[List[Any]] = None,
    scenario: Optional[Scenario] = None,
) -> None:
    """
    Remove `callback` (with `params`) from `scenario`, or from every scenario when none is given.
    """
    targets = cls._get_all() if scenario is None else [scenario]
    for target in targets:
        cls.__remove_subscriber(callback, params, target)
@classmethod
def __add_subscriber(cls, callback, params, scenario: Scenario) -> None:
    """
    Add the subscriber to `scenario` and publish an update event on its `subscribers` attribute.
    """
    scenario._add_subscriber(callback, params)
    update_event = _make_event(
        scenario,
        EventOperation.UPDATE,
        attribute_name="subscribers",
        attribute_value=params,
    )
    Notifier.publish(update_event)
@classmethod
def __remove_subscriber(cls, callback, params, scenario: Scenario) -> None:
    """
    Remove the subscriber from `scenario` and publish an update event on its `subscribers` attribute.
    """
    scenario._remove_subscriber(callback, params)
    update_event = _make_event(
        scenario,
        EventOperation.UPDATE,
        attribute_name="subscribers",
        attribute_value=params,
    )
    Notifier.publish(update_event)
@classmethod
def _can_create(cls, config: ScenarioConfig) -> Reasons:
    """
    Tell whether a scenario can be created from `config`.

    The returned `Reasons` is keyed on the `id` of `config` when it has a truthy one, otherwise on
    `str(config)`. A wrong-config-type reason is added when `config` is not a `ScenarioConfig`.
    """
    identifier = getattr(config, "id", None)
    if not identifier:
        identifier = str(config)

    reasons = Reasons(identifier)
    if isinstance(config, ScenarioConfig):
        return reasons

    reasons._add_reason(identifier, _build_wrong_config_type_reason(identifier))
    return reasons
@classmethod
def _create(
    cls,
    config: ScenarioConfig,
    creation_date: Optional[datetime.datetime] = None,
    name: Optional[str] = None,
) -> Scenario:
    """
    Build, persist and publish a new scenario from `config`.

    The tasks and additional data nodes declared by `config` are fetched or created through their
    managers, the configured sequences are resolved against those tasks, and, when `config` has a
    frequency, the scenario is attached to the cycle obtained from the cycle manager for that
    frequency and `creation_date`. A scenario is flagged as primary when its cycle holds no
    scenario yet.

    Parameters:
        config: Scenario configuration to instantiate.
        creation_date: Creation date given to the scenario and used to get or create its cycle.
        name: Optional name stored in the `name` property of the scenario.

    Returns:
        The newly created scenario.

    Raises:
        SequenceTaskConfigDoesNotExistInSameScenarioConfig: A sequence refers to a task config
            that is not among the tasks built from `config.task_configs`.
        InvalidScenario: The scenario reports itself as not consistent.
    """
    task_manager = _TaskManagerFactory._build_manager()
    data_manager = _DataManagerFactory._build_manager()
    config_id = str(config.id)
    scenario_id = Scenario._new_id(config_id)

    cycle = None
    if config.frequency:
        cycle = _CycleManagerFactory._build_manager()._get_or_create(config.frequency, creation_date)
    cycle_id = cycle.id if cycle else None

    tasks = []
    if config.task_configs:
        tasks = task_manager._bulk_get_or_create(config.task_configs, cycle_id, scenario_id)

    additional_data_nodes = {}
    if config.additional_data_node_configs:
        additional_data_nodes = data_manager._bulk_get_or_create(
            config.additional_data_node_configs, cycle_id, scenario_id
        )

    # Resolve each configured sequence into the task entities created above.
    tasks_by_config_id = {task.config_id: task for task in tasks}
    sequences = {}
    for sequence_name, sequence_task_configs in config.sequences.items():
        sequence_tasks = []
        unknown_task_config_ids = set()
        for task_config in sequence_task_configs:
            matching_task = tasks_by_config_id.get(task_config.id)
            if matching_task:
                sequence_tasks.append(matching_task)
            else:
                unknown_task_config_ids.add(task_config.id)
        if unknown_task_config_ids:
            raise SequenceTaskConfigDoesNotExistInSameScenarioConfig(
                list(unknown_task_config_ids), sequence_name, config_id
            )
        sequences[sequence_name] = {Scenario._SEQUENCE_TASKS_KEY: sequence_tasks}

    # The scenario is primary only when its cycle does not hold any scenario yet.
    if cycle:
        is_primary_scenario = len(cls._get_all_by_cycle(cycle)) == 0
    else:
        is_primary_scenario = False

    props = config._properties.copy()
    if name:
        props["name"] = name
    version = cls._get_latest_version()

    scenario = Scenario(
        config_id=config_id,
        tasks=set(tasks),
        properties=props,
        additional_data_nodes=set(additional_data_nodes.values()),
        scenario_id=scenario_id,
        creation_date=creation_date,
        is_primary=is_primary_scenario,
        cycle=cycle,
        version=version,
        sequences=sequences,
    )

    # Record the new scenario as a parent of every task and additional data node it uses.
    for task in tasks:
        if scenario_id in task._parent_ids:
            continue
        task._parent_ids.update([scenario_id])
        task_manager._set(task)

    for data_node in additional_data_nodes.values():
        if scenario_id in data_node._parent_ids:
            continue
        data_node._parent_ids.update([scenario_id])
        data_manager._set(data_node)

    cls._set(scenario)

    if not scenario._is_consistent():
        raise InvalidScenario(scenario.id)

    from ..sequence._sequence_manager_factory import _SequenceManagerFactory

    _SequenceManagerFactory._build_manager()._bulk_create_from_scenario(scenario)

    Notifier.publish(_make_event(scenario, EventOperation.CREATION))
    return scenario
@classmethod
def _is_submittable(cls, scenario: Union[Scenario, ScenarioId]) -> Reasons:
    """
    Tell whether `scenario` can be submitted.

    Parameters:
        scenario: A scenario, or the identifier of one. An identifier is resolved with `_get`.

    Returns:
        The result of `Scenario.is_ready_to_run()` when a scenario is found. Otherwise a `Reasons`
        holding a not-submittable-entity reason, keyed on the string form of the value the caller
        passed.
    """
    entity = cls._get(scenario) if isinstance(scenario, str) else scenario
    if isinstance(entity, Scenario):
        return entity.is_ready_to_run()

    # Report against the value the caller passed: when an identifier does not resolve, the lookup
    # result is None and `str(None)` would describe an entity called "None".
    entity_id = str(scenario)
    reasons = Reasons(entity_id)
    reasons._add_reason(entity_id, _build_not_submittable_entity_reason(entity_id))
    return reasons
@classmethod
def _submit(
    cls,
    scenario: Union[Scenario, ScenarioId],
    callbacks: Optional[List[Callable]] = None,
    force: bool = False,
    wait: bool = False,
    timeout: Optional[Union[float, int]] = None,
    check_inputs_are_ready: bool = True,
    **properties,
) -> Submission:
    """
    Submit a scenario to the orchestrator and publish a submission event.

    Parameters:
        scenario: The scenario, or the identifier of the scenario, to submit.
        callbacks: Extra callbacks appended after the ones built from the scenario subscribers.
        force: Forwarded to the orchestrator `submit` call.
        wait: Forwarded to the orchestrator `submit` call.
        timeout: Forwarded to the orchestrator `submit` call.
        check_inputs_are_ready: If True, `_warn_if_inputs_not_ready` is called on the scenario inputs.
        **properties: Forwarded to the orchestrator `submit` call.

    Returns:
        The submission returned by the orchestrator.

    Raises:
        NonExistingScenario: The scenario cannot be retrieved or does not exist.
    """
    if isinstance(scenario, Scenario):
        scenario_id = scenario.id
    else:
        scenario_id = scenario
        scenario = cls._get(scenario_id)

    if scenario is None or not cls._exists(scenario_id):
        raise NonExistingScenario(scenario_id)

    extra_callbacks = callbacks or []
    all_callbacks = cls.__get_status_notifier_callbacks(scenario) + extra_callbacks

    if check_inputs_are_ready:
        _warn_if_inputs_not_ready(scenario.get_inputs())

    orchestrator = _TaskManagerFactory._build_manager()._orchestrator()
    submission = orchestrator.submit(
        scenario,
        callbacks=all_callbacks,
        force=force,
        wait=wait,
        timeout=timeout,
        **properties,
    )
    Notifier.publish(_make_event(scenario, EventOperation.SUBMISSION))
    return submission
@classmethod
def __get_status_notifier_callbacks(cls, scenario: Scenario) -> List:
    """
    Build one partial per subscriber, binding the subscriber params followed by `scenario`.
    """
    notifier_callbacks = []
    for subscriber in scenario.subscribers:
        notifier_callbacks.append(partial(subscriber.callback, *subscriber.params, scenario))
    return notifier_callbacks
@classmethod
def _get_primary(cls, cycle: Cycle) -> Optional[Scenario]:
    """
    Return the first primary scenario found in `cycle`, or None when there is none.
    """
    cycle_scenarios = cls._get_all_by_cycle(cycle)
    return next((candidate for candidate in cycle_scenarios if candidate.is_primary), None)
@classmethod
def _get_by_tag(cls, cycle: Cycle, tag: str) -> Optional[Scenario]:
    """
    Return the first scenario of `cycle` that has `tag`, or None when there is none.
    """
    cycle_scenarios = cls._get_all_by_cycle(cycle)
    return next((candidate for candidate in cycle_scenarios if candidate.has_tag(tag)), None)
@classmethod
def _get_all_by_tag(cls, tag: str) -> List[Scenario]:
    """
    Return every scenario returned by `_get_all` that has `tag`.
    """
    tagged_scenarios = []
    for candidate in cls._get_all():
        if candidate.has_tag(tag):
            tagged_scenarios.append(candidate)
    return tagged_scenarios
@classmethod
def _get_all_by_cycle(cls, cycle: Cycle) -> List[Scenario]:
    """
    Return the scenarios attached to `cycle`, using the filters built for the "all" version.
    """
    # Fall back to a single empty filter so the cycle criterion is always applied.
    cycle_filters = cls._build_filters_with_version("all") or [{}]
    for cycle_filter in cycle_filters:
        cycle_filter["cycle"] = cycle.id
    return cls._get_all_by(cycle_filters)
@classmethod
def _get_primary_scenarios(cls) -> List[Scenario]:
    """
    Return every scenario returned by `_get_all` that is flagged as primary.
    """
    primary_scenarios = []
    for candidate in cls._get_all():
        if candidate.is_primary:
            primary_scenarios.append(candidate)
    return primary_scenarios
@classmethod
def _sort_scenarios(
    cls,
    scenarios: List[Scenario],
    descending: bool = False,
    sort_key: Literal["name", "id", "config_id", "creation_date", "tags"] = "name",
) -> List[Scenario]:
    """
    Sort `scenarios` in place and return the same list.

    Parameters:
        scenarios: The list to sort.
        descending: If True, the order is reversed.
        sort_key: Attribute to sort on. "tags" compares the sorted tuple of tags, "id" compares
            the id alone, and any unrecognized value falls back to "name". Except for "id", the
            scenario id is used as a tie-breaker.

    Returns:
        The `scenarios` list, sorted.
    """
    if sort_key == "tags":

        def ordering(scn):
            return (tuple(sorted(scn.tags)), scn.id)

    elif sort_key == "id":

        def ordering(scn):
            return scn.id

    else:
        # Any key outside the supported attribute names is sorted by name.
        attribute = sort_key if sort_key in ("config_id", "creation_date") else "name"

        def ordering(scn):
            return (getattr(scn, attribute), scn.id)

    scenarios.sort(key=ordering, reverse=descending)
    return scenarios
@classmethod
def _is_promotable_to_primary(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
    """
    Return True when the scenario exists, is not primary yet and belongs to a cycle.

    An identifier is resolved with `_get` first.
    """
    entity = cls._get(scenario) if isinstance(scenario, str) else scenario
    return bool(entity and not entity.is_primary and entity.cycle)
@classmethod
def _set_primary(cls, scenario: Scenario) -> None:
    """
    Flag `scenario` as the primary scenario of its cycle.

    The current primary scenario of the cycle, when it is a different scenario, is unflagged first.

    Raises:
        DoesNotBelongToACycle: `scenario` has no cycle.
    """
    if not scenario.cycle:
        raise DoesNotBelongToACycle(
            f"Can't set scenario {scenario.id} to primary because it doesn't belong to a cycle."
        )

    current_primary = cls._get_primary(scenario.cycle)
    # Skip the demotion when the current primary is the scenario itself,
    # to prevent the SAME scenario updating out of Context Manager.
    if current_primary:
        if current_primary != scenario:
            current_primary.is_primary = False  # type: ignore
    scenario.is_primary = True  # type: ignore
@classmethod
def _tag(cls, scenario: Scenario, tag: str) -> None:
    """
    Add `tag` to `scenario`, persist it and publish an update event on its `tags` attribute.

    When the scenario belongs to a cycle, the tag is first removed from the scenario of that cycle
    found by `_get_by_tag`, if any.

    Raises:
        UnauthorizedTagError: The scenario declares a non-empty set of authorized tags and `tag`
            is not one of them.
    """
    authorized_tags = scenario.properties.get(cls._AUTHORIZED_TAGS_KEY, set())
    has_restrictions = len(authorized_tags) > 0
    if has_restrictions and tag not in authorized_tags:
        raise UnauthorizedTagError(f"Tag `{tag}` not authorized by scenario configuration `{scenario.config_id}`")

    if scenario.cycle:
        previous_holder = cls._get_by_tag(scenario.cycle, tag)
        if previous_holder:
            previous_holder.remove_tag(tag)
            cls._set(previous_holder)

    scenario._add_tag(tag)
    cls._set(scenario)
    tags_event = _make_event(
        scenario,
        EventOperation.UPDATE,
        attribute_name="tags",
        attribute_value=scenario.tags,
    )
    Notifier.publish(tags_event)
@classmethod
def _untag(cls, scenario: Scenario, tag: str) -> None:
    """
    Remove `tag` from `scenario`, persist it and publish an update event on its `tags` attribute.
    """
    scenario._remove_tag(tag)
    cls._set(scenario)
    tags_event = _make_event(
        scenario,
        EventOperation.UPDATE,
        attribute_name="tags",
        attribute_value=scenario.tags,
    )
    Notifier.publish(tags_event)
@classmethod
def _compare(cls, *scenarios: Scenario, data_node_config_id: Optional[str] = None) -> Dict:
    """
    Run the configured comparators on the data nodes of several scenarios.

    Parameters:
        *scenarios: At least two scenarios sharing the same configuration id.
        data_node_config_id: When given, only the comparators of this data node config are run.
            Otherwise all comparators of the scenario configuration are run.

    Returns:
        A dictionary mapping each data node config id to a dictionary that maps each comparator
        name to the value it returned for the data read from every scenario.

    Raises:
        InsufficientScenarioToCompare: Fewer than two scenarios were given.
        DifferentScenarioConfigs: The scenarios do not all share the same config id.
        NonExistingScenarioConfig: No scenario configuration is registered for the config id.
        NonExistingComparator: `data_node_config_id` has no comparator in the configuration.
    """
    if len(scenarios) < 2:
        raise InsufficientScenarioToCompare("At least two scenarios are required to compare.")

    reference = scenarios[0]
    if not all(reference.config_id == other.config_id for other in scenarios):
        raise DifferentScenarioConfigs("Scenarios to compare must have the same configuration.")

    scenario_config = cls.__get_config(reference)
    if not scenario_config:
        raise NonExistingScenarioConfig(reference.config_id)

    if data_node_config_id:
        if data_node_config_id not in scenario_config.comparators.keys():
            raise NonExistingComparator(f"Data node config {data_node_config_id} has no comparator.")
        dn_comparators = {data_node_config_id: scenario_config.comparators[data_node_config_id]}
    else:
        dn_comparators = scenario_config.comparators

    results = {}
    for dn_config_id, comparators in dn_comparators.items():
        # Read the data node of that config id from every scenario, in argument order.
        read_values = [scn.__getattr__(dn_config_id).read() for scn in scenarios]
        results[dn_config_id] = {comparator.__name__: comparator(*read_values) for comparator in comparators}
    return results
@staticmethod
def __get_config(scenario: Scenario):
    """
    Return the scenario configuration registered under the config id of `scenario`, or None.
    """
    registered_configs = Config.scenarios
    return registered_configs.get(scenario.config_id, None)
@classmethod
def _is_deletable(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
    """
    Return False only when the scenario is primary and its cycle holds more than one scenario.

    An identifier is resolved with `_get` first.
    """
    if isinstance(scenario, str):
        scenario = cls._get(scenario)
    is_blocked = scenario.is_primary and len(cls._get_all_by_cycle(scenario.cycle)) > 1
    return not is_blocked
@classmethod
def _delete(cls, scenario_id: ScenarioId) -> None:
    """
    Delete the scenario identified by `scenario_id`.

    When the scenario is primary, its cycle is deleted through the cycle manager first.

    Raises:
        DeletingPrimaryScenario: `_is_deletable` rejects the scenario.
    """
    target = cls._get(scenario_id)
    if not cls._is_deletable(target):
        raise DeletingPrimaryScenario(
            f"Scenario {target.id}, which has config id {target.config_id}, is primary and there are "
            "other scenarios in the same cycle. "
        )

    if target.is_primary:
        cycle_manager = _CycleManagerFactory._build_manager()
        cycle_manager._delete(target.cycle.id)
    super()._delete(scenario_id)
@classmethod
def _hard_delete(cls, scenario_id: ScenarioId) -> None:
    """
    Hard-delete the scenario identified by `scenario_id`.

    A primary scenario is removed by hard-deleting its cycle through the cycle manager. Any other
    scenario is deleted together with the children entities collected by `_get_children_entity_ids`.

    Raises:
        DeletingPrimaryScenario: `_is_deletable` rejects the scenario.
    """
    target = cls._get(scenario_id)
    if not cls._is_deletable(target):
        raise DeletingPrimaryScenario(
            f"Scenario {target.id}, which has config id {target.config_id}, is primary and there are "
            "other scenarios in the same cycle. "
        )

    if target.is_primary:
        _CycleManagerFactory._build_manager()._hard_delete(target.cycle.id)
        return

    doomed_entity_ids = cls._get_children_entity_ids(target)
    doomed_entity_ids.scenario_ids.add(target.id)
    cls._delete_entities_of_multiple_types(doomed_entity_ids)
@classmethod
def _delete_by_version(cls, version_number: str) -> None:
    """
    Delete every scenario whose version is `version_number`.

    A cycle is deleted as well when the scenario being deleted is the only one attached to it.
    """
    versioned_scenarios = cls._repository._search("version", version_number)
    for versioned in versioned_scenarios:
        is_last_of_cycle = versioned.cycle and len(cls._get_all_by_cycle(versioned.cycle)) == 1
        if is_last_of_cycle:
            cycle_manager = _CycleManagerFactory._build_manager()
            cycle_manager._delete(versioned.cycle.id)
        super()._delete(versioned.id)
@classmethod
def _get_children_entity_ids(cls, scenario: Scenario) -> _EntityIds:
    """
    Collect the ids of the entities attached to `scenario`.

    The result holds:
    - the sequences, tasks and data nodes of the scenario whose `owner_id` is the scenario id,
    - the jobs whose task is one of the collected tasks,
    - the submissions whose `entity_id` is the scenario id or one of the collected sequence or
      task ids.

    The scenario id itself is not added to `scenario_ids`.

    Parameters:
        scenario: The scenario whose children are collected.

    Returns:
        An `_EntityIds` filled with the collected ids.
    """
    entity_ids = _EntityIds()

    for sequence in scenario.sequences.values():
        if sequence.owner_id == scenario.id:
            entity_ids.sequence_ids.add(sequence.id)

    for task in scenario.tasks.values():
        if task.owner_id == scenario.id:
            entity_ids.task_ids.add(task.id)

    for data_node in scenario.data_nodes.values():
        if data_node.owner_id == scenario.id:
            entity_ids.data_node_ids.add(data_node.id)

    for job in _JobManagerFactory._build_manager()._get_all():
        if job.task.id in entity_ids.task_ids:
            entity_ids.job_ids.add(job.id)

    # Keep the submitted ids in a set, scenario id included, so each submission costs a single
    # hash lookup instead of a scan over a list plus a separate comparison.
    submitted_entity_ids = entity_ids.scenario_ids.union(
        entity_ids.sequence_ids, entity_ids.task_ids, [scenario.id]
    )
    for submission in _SubmissionManagerFactory._build_manager()._get_all():
        if submission.entity_id in submitted_entity_ids:
            entity_ids.submission_ids.add(submission.id)

    return entity_ids
@classmethod
def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None) -> List[Scenario]:
    """
    Load the scenarios whose config id is `config_id`, within the filters built from `version_number`.
    """
    # Fall back to a single empty filter so the config id criterion is always applied.
    config_filters = cls._build_filters_with_version(version_number) or [{}]
    for config_filter in config_filters:
        config_filter["config_id"] = config_id
    return cls._repository._load_all(config_filters)
@classmethod
def _import_scenario_and_children_entities(
    cls,
    zip_file_path: pathlib.Path,
    override: bool,
    entity_managers: Dict[str, Type[_Manager]],
) -> Optional[Scenario]:
    """
    Import a scenario and its related entities from a zip archive.

    The archive is extracted into a temporary directory. It must contain a `scenarios` (or
    `scenario`) folder and a `version` folder. The version is imported first, then, for every
    other key of `entity_managers` whose folder exists in the archive, each file of the folder is
    imported with the matching manager. When an error occurs during that phase, the entities
    imported so far are deleted and the error is re-raised.

    Parameters:
        zip_file_path: Path of the archive to import.
        override: If False, finding an entity that already exists aborts the import. If True, a
            warning is logged and the entity is imported anyway.
        entity_managers: Maps each entity folder name to the manager importing its files. It is
            indexed with the "version" key.

    Returns:
        The last entity imported from a `scenario` or `scenarios` folder.

    Raises:
        ImportArchiveDoesntContainAnyScenario: The archive has no `scenarios` or `scenario` folder.
        ImportScenarioDoesntHaveAVersion: The archive has no `version` folder.
        EntitiesToBeImportAlredyExist: An entity already exists and `override` is False.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        with zipfile.ZipFile(zip_file_path) as zip_file:
            zip_file.extractall(tmp_dir)

        tmp_dir_path = pathlib.Path(tmp_dir)

        if not ((tmp_dir_path / "scenarios").exists() or (tmp_dir_path / "scenario").exists()):
            raise ImportArchiveDoesntContainAnyScenario(zip_file_path)

        if not (tmp_dir_path / "version").exists():
            raise ImportScenarioDoesntHaveAVersion(zip_file_path)

        # Import the version to check for compatibility
        # Only the first entry yielded by iterating the version folder is imported.
        entity_managers["version"]._import(next((tmp_dir_path / "version").iterdir()), "")

        valid_entity_folders = list(entity_managers.keys())
        valid_data_folder = Config.core.storage_folder

        imported_scenario = None
        # Ids of the imported entities per entity type, kept for the rollback below.
        imported_entities: Dict[str, List] = {}

        # This pass only logs a warning for the unexpected entries of the archive.
        for entity_folder in tmp_dir_path.iterdir():
            if not entity_folder.is_dir() or entity_folder.name not in valid_entity_folders + [valid_data_folder]:
                cls._logger.warning(f"{entity_folder} is not a valid Taipy folder and will not be imported.")
                continue

        try:
            # Entity types are imported in the key order of `entity_managers`.
            for entity_type in valid_entity_folders:
                # Skip the version folder as it is already handled
                if entity_type == "version":
                    continue

                entity_folder = tmp_dir_path / entity_type
                if not entity_folder.exists():
                    continue

                manager = entity_managers[entity_type]
                imported_entities[entity_type] = []
                for entity_file in entity_folder.iterdir():
                    # Check if the to-be-imported entity already exists
                    # The entity id is the file name without its suffix.
                    entity_id = entity_file.stem
                    if manager._exists(entity_id):
                        if override:
                            cls._logger.warning(f"{entity_id} already exists and will be overridden.")
                        else:
                            cls._logger.error(
                                f"{entity_id} already exists. Please use the 'override' parameter to override it."
                            )
                            raise EntitiesToBeImportAlredyExist(zip_file_path)

                    # Import the entity
                    imported_entity = manager._import(
                        entity_file,
                        version=_VersionManagerFactory._build_manager()._get_latest_version(),
                        data_folder=tmp_dir_path / valid_data_folder,
                    )

                    imported_entities[entity_type].append(imported_entity.id)
                    if entity_type in ["scenario", "scenarios"]:
                        imported_scenario = imported_entity
        except Exception as err:
            cls._logger.error(f"An error occurred during the import: {err}. Rollback the import.")

            # Rollback the import
            # Entity types are rolled back in the reverse of their import order.
            # NOTE(review): an entity that existed before the import and was re-imported with
            # `override` is also listed here and gets deleted - confirm this is intended.
            for entity_type, entity_ids in list(imported_entities.items())[::-1]:
                manager = entity_managers[entity_type]
                for entity_id in entity_ids:
                    if manager._exists(entity_id):
                        manager._delete(entity_id)
            raise err

    # NOTE(review): `imported_scenario` is still None when no file was imported from a `scenario`
    # or `scenarios` folder; the `.id` access below would then raise AttributeError.
    cls._logger.info(f"Scenario {imported_scenario.id} has been successfully imported.")  # type: ignore[union-attr]
    return imported_scenario
|