_scenario_manager.py 22 KB


  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import datetime
  12. import pathlib
  13. import tempfile
  14. import zipfile
  15. from functools import partial
  16. from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
  17. from taipy.config import Config
  18. from .._entity._entity_ids import _EntityIds
  19. from .._manager._manager import _Manager
  20. from .._repository._abstract_repository import _AbstractRepository
  21. from .._version._version_manager_factory import _VersionManagerFactory
  22. from .._version._version_mixin import _VersionMixin
  23. from ..common.warn_if_inputs_not_ready import _warn_if_inputs_not_ready
  24. from ..config.scenario_config import ScenarioConfig
  25. from ..cycle._cycle_manager_factory import _CycleManagerFactory
  26. from ..cycle.cycle import Cycle
  27. from ..data._data_manager_factory import _DataManagerFactory
  28. from ..exceptions.exceptions import (
  29. DeletingPrimaryScenario,
  30. DifferentScenarioConfigs,
  31. DoesNotBelongToACycle,
  32. EntitiesToBeImportAlredyExist,
  33. ImportArchiveDoesntContainAnyScenario,
  34. ImportScenarioDoesntHaveAVersion,
  35. InsufficientScenarioToCompare,
  36. InvalidScenario,
  37. NonExistingComparator,
  38. NonExistingScenario,
  39. NonExistingScenarioConfig,
  40. SequenceTaskConfigDoesNotExistInSameScenarioConfig,
  41. UnauthorizedTagError,
  42. )
  43. from ..job._job_manager_factory import _JobManagerFactory
  44. from ..job.job import Job
  45. from ..notification import EventEntityType, EventOperation, Notifier, _make_event
  46. from ..reason._reason_factory import _build_not_submittable_entity_reason, _build_wrong_config_type_reason
  47. from ..reason.reason import Reasons
  48. from ..submission._submission_manager_factory import _SubmissionManagerFactory
  49. from ..submission.submission import Submission
  50. from ..task._task_manager_factory import _TaskManagerFactory
  51. from .scenario import Scenario
  52. from .scenario_id import ScenarioId
  class _ScenarioManager(_Manager[Scenario], _VersionMixin):
      """Manager grouping the class-level operations available on `Scenario` entities.

      All operations defined here are class methods or static methods; scenarios are
      loaded through the class-level `_repository`.
      """

      # Key looked up in a scenario's properties to find the tags it is allowed to carry.
      _AUTHORIZED_TAGS_KEY = "authorized_tags"
      # Entity name, taken from the name of the Scenario class.
      _ENTITY_NAME = Scenario.__name__
      # NOTE(review): presumably consumed by the base manager when publishing events - confirm.
      _EVENT_ENTITY_TYPE = EventEntityType.SCENARIO
      # Storage backend used to load and search scenarios.
      _repository: _AbstractRepository
  @classmethod
  def _get_all(cls, version_number: Optional[str] = None) -> List[Scenario]:
      """
      Load every scenario matching the filters built from `version_number`.
      """
      version_filters = cls._build_filters_with_version(version_number)
      loaded_scenarios = cls._repository._load_all(version_filters)
      return loaded_scenarios
  @classmethod
  def _subscribe(
      cls,
      callback: Callable[[Scenario, Job], None],
      params: Optional[List[Any]] = None,
      scenario: Optional[Scenario] = None,
  ) -> None:
      """
      Register `callback` (with `params`) as a subscriber.

      When `scenario` is None, the subscriber is added to every scenario returned
      by `_get_all()`; otherwise only to the given scenario.
      """
      targets = [scenario] if scenario is not None else cls._get_all()
      for target in targets:
          cls.__add_subscriber(callback, params, target)
  @classmethod
  def _unsubscribe(
      cls,
      callback: Callable[[Scenario, Job], None],
      params: Optional[List[Any]] = None,
      scenario: Optional[Scenario] = None,
  ) -> None:
      """
      Remove the subscriber made of `callback` and `params`.

      When `scenario` is None, the subscriber is removed from every scenario
      returned by `_get_all()`; otherwise only from the given scenario.
      """
      targets = [scenario] if scenario is not None else cls._get_all()
      for target in targets:
          cls.__remove_subscriber(callback, params, target)
  @classmethod
  def __add_subscriber(cls, callback, params, scenario: Scenario) -> None:
      """
      Add the subscriber to `scenario`, then publish an UPDATE event for its
      `subscribers` attribute carrying `params` as the value.
      """
      scenario._add_subscriber(callback, params)
      update_event = _make_event(
          scenario,
          EventOperation.UPDATE,
          attribute_name="subscribers",
          attribute_value=params,
      )
      Notifier.publish(update_event)
  @classmethod
  def __remove_subscriber(cls, callback, params, scenario: Scenario) -> None:
      """
      Remove the subscriber from `scenario`, then publish an UPDATE event for its
      `subscribers` attribute carrying `params` as the value.
      """
      scenario._remove_subscriber(callback, params)
      update_event = _make_event(
          scenario,
          EventOperation.UPDATE,
          attribute_name="subscribers",
          attribute_value=params,
      )
      Notifier.publish(update_event)
  @classmethod
  def _can_create(cls, config: ScenarioConfig) -> Reasons:
      """
      Build the `Reasons` object for creating a scenario from `config`.

      A wrong-config-type reason is recorded when `config` is not a
      `ScenarioConfig`; otherwise the returned object has no reason added here.
      """
      identifier = getattr(config, "id", None)
      if not identifier:
          # No usable `id` attribute: fall back to the string form of the object.
          identifier = str(config)

      reasons = Reasons(identifier)
      if isinstance(config, ScenarioConfig):
          return reasons

      reasons._add_reason(identifier, _build_wrong_config_type_reason(identifier))
      return reasons
  @classmethod
  def _create(
      cls,
      config: ScenarioConfig,
      creation_date: Optional[datetime.datetime] = None,
      name: Optional[str] = None,
  ) -> Scenario:
      """
      Create, persist and announce a new scenario built from `config`.

      Parameters:
          config: Scenario configuration to instantiate.
          creation_date: Passed to the cycle manager and to the new scenario.
          name: When truthy, stored under the "name" property of the scenario.

      Returns:
          The newly created scenario.

      Raises:
          SequenceTaskConfigDoesNotExistInSameScenarioConfig: A sequence refers to a
              task config that has no matching task among the scenario tasks.
          InvalidScenario: The scenario reports itself as inconsistent. The check
              runs after the scenario has been passed to `_set`.
      """
      task_manager = _TaskManagerFactory._build_manager()
      data_manager = _DataManagerFactory._build_manager()

      config_id = str(config.id)
      scenario_id = Scenario._new_id(config_id)

      # A cycle is only looked up (or created) when the config declares a frequency.
      cycle = None
      if config.frequency:
          cycle = _CycleManagerFactory._build_manager()._get_or_create(config.frequency, creation_date)
      cycle_id = cycle.id if cycle else None

      tasks = []
      if config.task_configs:
          tasks = task_manager._bulk_get_or_create(config.task_configs, cycle_id, scenario_id)

      additional_data_nodes = {}
      if config.additional_data_node_configs:
          additional_data_nodes = data_manager._bulk_get_or_create(
              config.additional_data_node_configs, cycle_id, scenario_id
          )

      # Resolve each sequence's task configs to the tasks obtained above.
      task_by_config_id = {task.config_id: task for task in tasks}
      sequences = {}
      for sequence_name, sequence_task_configs in config.sequences.items():
          resolved_tasks = []
          missing_config_ids = set()
          for task_config in sequence_task_configs:
              matching_task = task_by_config_id.get(task_config.id)
              if matching_task:
                  resolved_tasks.append(matching_task)
              else:
                  missing_config_ids.add(task_config.id)
          if missing_config_ids:
              raise SequenceTaskConfigDoesNotExistInSameScenarioConfig(
                  list(missing_config_ids), sequence_name, config_id
              )
          sequences[sequence_name] = {Scenario._SEQUENCE_TASKS_KEY: resolved_tasks}

      # The first scenario found for a cycle becomes its primary scenario.
      is_primary_scenario = False
      if cycle:
          is_primary_scenario = len(cls._get_all_by_cycle(cycle)) == 0

      props = config._properties.copy()
      if name:
          props["name"] = name
      version = cls._get_latest_version()

      scenario = Scenario(
          config_id=config_id,
          tasks=set(tasks),
          properties=props,
          additional_data_nodes=set(additional_data_nodes.values()),
          scenario_id=scenario_id,
          creation_date=creation_date,
          is_primary=is_primary_scenario,
          cycle=cycle,
          version=version,
          sequences=sequences,
      )

      # Record the new scenario as a parent of the tasks and data nodes that lack it.
      for task in tasks:
          if scenario_id in task._parent_ids:
              continue
          task._parent_ids.update([scenario_id])
          task_manager._set(task)

      for data_node in additional_data_nodes.values():
          if scenario_id in data_node._parent_ids:
              continue
          data_node._parent_ids.update([scenario_id])
          data_manager._set(data_node)

      cls._set(scenario)

      if not scenario._is_consistent():
          raise InvalidScenario(scenario.id)

      # Imported locally, as in the original implementation.
      from ..sequence._sequence_manager_factory import _SequenceManagerFactory

      sequence_manager = _SequenceManagerFactory._build_manager()
      sequence_manager._bulk_create_from_scenario(scenario)

      Notifier.publish(_make_event(scenario, EventOperation.CREATION))
      return scenario
  @classmethod
  def _is_submittable(cls, scenario: Union[Scenario, ScenarioId]) -> Reasons:
      """
      Tell whether `scenario` can be submitted.

      A string argument is first resolved through `_get`. If the result is a
      `Scenario`, the answer is delegated to `scenario.is_ready_to_run()`;
      otherwise a `Reasons` object holding a not-submittable reason is returned.
      """
      if isinstance(scenario, str):
          scenario = cls._get(scenario)

      if isinstance(scenario, Scenario):
          return scenario.is_ready_to_run()

      entity_label = str(scenario)
      reasons = Reasons(entity_label)
      reasons._add_reason(entity_label, _build_not_submittable_entity_reason(entity_label))
      return reasons
  @classmethod
  def _submit(
      cls,
      scenario: Union[Scenario, ScenarioId],
      callbacks: Optional[List[Callable]] = None,
      force: bool = False,
      wait: bool = False,
      timeout: Optional[Union[float, int]] = None,
      check_inputs_are_ready: bool = True,
      **properties,
  ) -> Submission:
      """
      Submit a scenario to the orchestrator and publish a SUBMISSION event.

      Parameters:
          scenario: The scenario, or its identifier, to submit.
          callbacks: Extra callbacks appended after the scenario subscribers' callbacks.
          force, wait, timeout: Forwarded unchanged to the orchestrator's `submit`.
          check_inputs_are_ready: When true, `_warn_if_inputs_not_ready` is called on
              the scenario inputs before submitting.
          **properties: Forwarded unchanged to the orchestrator's `submit`.

      Returns:
          The submission returned by the orchestrator.

      Raises:
          NonExistingScenario: The scenario could not be retrieved or does not exist.
      """
      if isinstance(scenario, Scenario):
          scenario_id = scenario.id
      else:
          scenario_id = scenario
          scenario = cls._get(scenario_id)

      if scenario is None or not cls._exists(scenario_id):
          raise NonExistingScenario(scenario_id)

      extra_callbacks = callbacks or []
      all_callbacks = cls.__get_status_notifier_callbacks(scenario) + extra_callbacks

      if check_inputs_are_ready:
          _warn_if_inputs_not_ready(scenario.get_inputs())

      orchestrator = _TaskManagerFactory._build_manager()._orchestrator()
      submission = orchestrator.submit(
          scenario,
          callbacks=all_callbacks,
          force=force,
          wait=wait,
          timeout=timeout,
          **properties,
      )

      Notifier.publish(_make_event(scenario, EventOperation.SUBMISSION))
      return submission
  @classmethod
  def __get_status_notifier_callbacks(cls, scenario: Scenario) -> List:
      """
      Build one partial per subscriber of `scenario`, binding the subscriber's
      params followed by the scenario itself as leading positional arguments.
      """
      bound_callbacks = []
      for subscriber in scenario.subscribers:
          bound_callbacks.append(partial(subscriber.callback, *subscriber.params, scenario))
      return bound_callbacks
  @classmethod
  def _get_primary(cls, cycle: Cycle) -> Optional[Scenario]:
      """
      Return the first scenario of `cycle` flagged as primary, or None if there is none.
      """
      cycle_scenarios = cls._get_all_by_cycle(cycle)
      return next((candidate for candidate in cycle_scenarios if candidate.is_primary), None)
  @classmethod
  def _get_by_tag(cls, cycle: Cycle, tag: str) -> Optional[Scenario]:
      """
      Return the first scenario of `cycle` carrying `tag`, or None if there is none.
      """
      cycle_scenarios = cls._get_all_by_cycle(cycle)
      return next((candidate for candidate in cycle_scenarios if candidate.has_tag(tag)), None)
  @classmethod
  def _get_all_by_tag(cls, tag: str) -> List[Scenario]:
      """
      Return the scenarios from `_get_all()` that carry `tag`.
      """
      tagged_scenarios = []
      for candidate in cls._get_all():
          if candidate.has_tag(tag):
              tagged_scenarios.append(candidate)
      return tagged_scenarios
  @classmethod
  def _get_all_by_cycle(cls, cycle: Cycle) -> List[Scenario]:
      """
      Return the scenarios attached to `cycle`, using the filters built for version "all".
      """
      # Fall back to a single empty filter so the cycle criterion is always applied.
      cycle_filters = cls._build_filters_with_version("all") or [{}]
      for criteria in cycle_filters:
          criteria["cycle"] = cycle.id
      return cls._get_all_by(cycle_filters)
  @classmethod
  def _get_primary_scenarios(cls) -> List[Scenario]:
      """
      Return the scenarios from `_get_all()` that are flagged as primary.
      """
      primary_scenarios = []
      for candidate in cls._get_all():
          if candidate.is_primary:
              primary_scenarios.append(candidate)
      return primary_scenarios
  @classmethod
  def _sort_scenarios(
      cls,
      scenarios: List[Scenario],
      descending: bool = False,
      sort_key: Literal["name", "id", "config_id", "creation_date", "tags"] = "name",
  ) -> List[Scenario]:
      """
      Sort `scenarios` in place and return the same list.

      Parameters:
          scenarios: The list to sort; it is mutated.
          descending: Sort in reverse order when true.
          sort_key: Attribute to sort on. Any unrecognized value sorts by name.
              Except for "id", the scenario id is used as a tie-breaker.
      """
      if sort_key == "tags":

          def ordering(scn):
              # Tags are sorted and frozen into a tuple so they compare consistently.
              return tuple(sorted(scn.tags)), scn.id

      elif sort_key in ("name", "config_id", "creation_date"):

          def ordering(scn):
              return getattr(scn, sort_key), scn.id

      elif sort_key == "id":

          def ordering(scn):
              return scn.id

      else:

          def ordering(scn):
              return scn.name, scn.id

      scenarios.sort(key=ordering, reverse=descending)
      return scenarios
  @classmethod
  def _is_promotable_to_primary(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
      """
      Tell whether `scenario` can become primary.

      A string argument is resolved through `_get`. The answer is True only for an
      existing scenario that is not already primary and that has a cycle.
      """
      if isinstance(scenario, str):
          scenario = cls._get(scenario)
      return bool(scenario and not scenario.is_primary and scenario.cycle)
  @classmethod
  def _set_primary(cls, scenario: Scenario) -> None:
      """
      Make `scenario` the primary scenario of its cycle.

      The scenario currently primary in that cycle, if it is a different one, is
      demoted first.

      Raises:
          DoesNotBelongToACycle: `scenario` has no cycle.
      """
      if not scenario.cycle:
          raise DoesNotBelongToACycle(
              f"Can't set scenario {scenario.id} to primary because it doesn't belong to a cycle."
          )

      current_primary = cls._get_primary(scenario.cycle)
      # To prevent SAME scenario updating out of Context Manager
      if current_primary:
          if current_primary != scenario:
              current_primary.is_primary = False  # type: ignore
      scenario.is_primary = True  # type: ignore
  @classmethod
  def _tag(cls, scenario: Scenario, tag: str) -> None:
      """
      Add `tag` to `scenario`, persist it and publish an UPDATE event for its tags.

      When the scenario has a cycle, the tag is first removed from the scenario of
      that cycle currently carrying it, if any.

      Raises:
          UnauthorizedTagError: The scenario properties list authorized tags and
              `tag` is not one of them.
      """
      authorized_tags = scenario.properties.get(cls._AUTHORIZED_TAGS_KEY, set())
      if len(authorized_tags) > 0 and tag not in authorized_tags:
          raise UnauthorizedTagError(f"Tag `{tag}` not authorized by scenario configuration `{scenario.config_id}`")

      if scenario.cycle:
          previous_holder = cls._get_by_tag(scenario.cycle, tag)
          if previous_holder:
              previous_holder.remove_tag(tag)
              cls._set(previous_holder)

      scenario._add_tag(tag)
      cls._set(scenario)

      tags_event = _make_event(
          scenario,
          EventOperation.UPDATE,
          attribute_name="tags",
          attribute_value=scenario.tags,
      )
      Notifier.publish(tags_event)
  @classmethod
  def _untag(cls, scenario: Scenario, tag: str) -> None:
      """
      Remove `tag` from `scenario`, persist it and publish an UPDATE event for its tags.
      """
      scenario._remove_tag(tag)
      cls._set(scenario)

      tags_event = _make_event(
          scenario,
          EventOperation.UPDATE,
          attribute_name="tags",
          attribute_value=scenario.tags,
      )
      Notifier.publish(tags_event)
  @classmethod
  def _compare(cls, *scenarios: Scenario, data_node_config_id: Optional[str] = None) -> Dict:
      """
      Run the configured comparators on the data nodes of several scenarios.

      Parameters:
          *scenarios: At least two scenarios sharing the same config id.
          data_node_config_id: When given, only the comparators registered for this
              data node config are run; otherwise all configured comparators are.

      Returns:
          A dictionary keyed by data node config id. Each value maps a comparator
          name to the result of calling it with the data read from every scenario.

      Raises:
          InsufficientScenarioToCompare: Fewer than two scenarios were given.
          DifferentScenarioConfigs: The scenarios do not share the same config id.
          NonExistingScenarioConfig: No configuration is found for the first scenario.
          NonExistingComparator: `data_node_config_id` has no comparator.
      """
      if len(scenarios) < 2:
          raise InsufficientScenarioToCompare("At least two scenarios are required to compare.")

      reference_config_id = scenarios[0].config_id
      if any(reference_config_id != scenario.config_id for scenario in scenarios):
          raise DifferentScenarioConfigs("Scenarios to compare must have the same configuration.")

      scenario_config = cls.__get_config(scenarios[0])
      if not scenario_config:
          raise NonExistingScenarioConfig(scenarios[0].config_id)

      if not data_node_config_id:
          dn_comparators = scenario_config.comparators
      elif data_node_config_id in scenario_config.comparators:
          dn_comparators = {data_node_config_id: scenario_config.comparators[data_node_config_id]}
      else:
          raise NonExistingComparator(f"Data node config {data_node_config_id} has no comparator.")

      results = {}
      for dn_config_id, comparators in dn_comparators.items():
          # `__getattr__` is called directly, exactly as the original implementation does.
          compared_data = [scenario.__getattr__(dn_config_id).read() for scenario in scenarios]
          comparator_results = {}
          for comparator in comparators:
              comparator_results[comparator.__name__] = comparator(*compared_data)
          results[dn_config_id] = comparator_results
      return results
  @staticmethod
  def __get_config(scenario: Scenario):
      """
      Look up the configuration registered under the scenario's config id, or None.
      """
      registered_configs = Config.scenarios
      return registered_configs.get(scenario.config_id, None)
  @classmethod
  def _is_deletable(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
      """
      Tell whether `scenario` can be deleted.

      A string argument is resolved through `_get`. A primary scenario is not
      deletable while its cycle holds more than one scenario; any other scenario is.
      """
      if isinstance(scenario, str):
          scenario = cls._get(scenario)
      blocked = scenario.is_primary and len(cls._get_all_by_cycle(scenario.cycle)) > 1
      return not blocked
  @classmethod
  def _delete(cls, scenario_id: ScenarioId) -> None:
      """
      Delete the scenario identified by `scenario_id`.

      When the scenario is primary, its cycle is deleted first through the cycle manager.

      Raises:
          DeletingPrimaryScenario: `_is_deletable` rejects the scenario.
      """
      target = cls._get(scenario_id)
      deletable = cls._is_deletable(target)
      if not deletable:
          raise DeletingPrimaryScenario(
              f"Scenario {target.id}, which has config id {target.config_id}, is primary and there are "
              f"other scenarios in the same cycle. "
          )

      if target.is_primary:
          cycle_manager = _CycleManagerFactory._build_manager()
          cycle_manager._delete(target.cycle.id)

      super()._delete(scenario_id)
  @classmethod
  def _hard_delete(cls, scenario_id: ScenarioId) -> None:
      """
      Hard-delete the scenario identified by `scenario_id`.

      A primary scenario is removed by hard-deleting its cycle through the cycle
      manager. Any other scenario is removed together with the child entities
      collected by `_get_children_entity_ids`.

      Raises:
          DeletingPrimaryScenario: `_is_deletable` rejects the scenario.
      """
      target = cls._get(scenario_id)
      deletable = cls._is_deletable(target)
      if not deletable:
          raise DeletingPrimaryScenario(
              f"Scenario {target.id}, which has config id {target.config_id}, is primary and there are "
              f"other scenarios in the same cycle. "
          )

      if target.is_primary:
          cycle_manager = _CycleManagerFactory._build_manager()
          cycle_manager._hard_delete(target.cycle.id)
          return

      ids_to_delete = cls._get_children_entity_ids(target)
      ids_to_delete.scenario_ids.add(target.id)
      cls._delete_entities_of_multiple_types(ids_to_delete)
  @classmethod
  def _delete_by_version(cls, version_number: str) -> None:
      """
      Delete every scenario whose version is `version_number`.

      For each of them, the cycle is deleted as well when that scenario is the
      only one attached to it.
      """
      matching_scenarios = cls._repository._search("version", version_number)
      for scenario in matching_scenarios:
          is_last_of_cycle = scenario.cycle and len(cls._get_all_by_cycle(scenario.cycle)) == 1
          if is_last_of_cycle:
              cycle_manager = _CycleManagerFactory._build_manager()
              cycle_manager._delete(scenario.cycle.id)
          super()._delete(scenario.id)
  @classmethod
  def _get_children_entity_ids(cls, scenario: Scenario) -> _EntityIds:
      """
      Collect the ids of the entities that depend on `scenario`.

      The result holds the sequences, tasks and data nodes whose owner is the
      scenario, the jobs running one of those tasks, and the submissions whose
      entity is the scenario or one of the collected sequences or tasks.
      """
      collected = _EntityIds()
      owner_id = scenario.id

      collected.sequence_ids.update(
          sequence.id for sequence in scenario.sequences.values() if sequence.owner_id == owner_id
      )
      collected.task_ids.update(task.id for task in scenario.tasks.values() if task.owner_id == owner_id)
      collected.data_node_ids.update(
          data_node.id for data_node in scenario.data_nodes.values() if data_node.owner_id == owner_id
      )

      all_jobs = _JobManagerFactory._build_manager()._get_all()
      for job in all_jobs:
          if job.task.id in collected.task_ids:
              collected.job_ids.add(job.id)

      all_submissions = _SubmissionManagerFactory._build_manager()._get_all()
      submitted_entity_ids = list(collected.scenario_ids.union(collected.sequence_ids, collected.task_ids))
      for submission in all_submissions:
          # The scenario's own id is tested separately from the collected ids.
          if submission.entity_id in submitted_entity_ids or submission.entity_id == scenario.id:
              collected.submission_ids.add(submission.id)

      return collected
  @classmethod
  def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None) -> List[Scenario]:
      """
      Load the scenarios created from `config_id`, restricted by the filters built
      from `version_number`.
      """
      # Fall back to a single empty filter so the config id criterion is always applied.
      config_filters = cls._build_filters_with_version(version_number) or [{}]
      for criteria in config_filters:
          criteria["config_id"] = config_id
      return cls._repository._load_all(config_filters)
  @classmethod
  def _import_scenario_and_children_entities(
      cls,
      zip_file_path: pathlib.Path,
      override: bool,
      entity_managers: Dict[str, Type[_Manager]],
  ) -> Optional[Scenario]:
      """
      Import a scenario and its related entities from a zip archive.

      The archive is extracted into a temporary directory. The version is imported
      first, then every entity folder that has a manager in `entity_managers` is
      imported file by file. If anything fails, the entities imported so far are
      deleted and the error is re-raised.

      Parameters:
          zip_file_path: Path of the archive to import.
          override: When true, entities that already exist are overridden (a warning
              is logged); when false, an existing entity aborts the import.
          entity_managers: Managers keyed by entity folder name; must contain "version".

      Returns:
          The imported scenario.

      Raises:
          ImportArchiveDoesntContainAnyScenario: The archive has no "scenarios" or
              "scenario" folder, or no scenario was imported from it.
          ImportScenarioDoesntHaveAVersion: The archive has no "version" folder.
          EntitiesToBeImportAlredyExist: An entity already exists and `override` is false.
      """
      with tempfile.TemporaryDirectory() as tmp_dir:
          with zipfile.ZipFile(zip_file_path) as zip_file:
              zip_file.extractall(tmp_dir)

          tmp_dir_path = pathlib.Path(tmp_dir)

          if not ((tmp_dir_path / "scenarios").exists() or (tmp_dir_path / "scenario").exists()):
              raise ImportArchiveDoesntContainAnyScenario(zip_file_path)

          if not (tmp_dir_path / "version").exists():
              raise ImportScenarioDoesntHaveAVersion(zip_file_path)

          # Import the version to check for compatibility
          entity_managers["version"]._import(next((tmp_dir_path / "version").iterdir()), "")

          valid_entity_folders = list(entity_managers.keys())
          valid_data_folder = Config.core.storage_folder
          # Built once: the accepted names do not change while scanning the archive.
          accepted_folder_names = valid_entity_folders + [valid_data_folder]

          imported_scenario = None
          imported_entities: Dict[str, List] = {}

          # Unknown top-level items are only reported; they are never imported.
          for entity_folder in tmp_dir_path.iterdir():
              if not entity_folder.is_dir() or entity_folder.name not in accepted_folder_names:
                  cls._logger.warning(f"{entity_folder} is not a valid Taipy folder and will not be imported.")

          try:
              for entity_type in valid_entity_folders:
                  # Skip the version folder as it is already handled
                  if entity_type == "version":
                      continue

                  entity_folder = tmp_dir_path / entity_type
                  if not entity_folder.exists():
                      continue

                  manager = entity_managers[entity_type]
                  imported_entities[entity_type] = []

                  for entity_file in entity_folder.iterdir():
                      # Check if the to-be-imported entity already exists
                      entity_id = entity_file.stem
                      if manager._exists(entity_id):
                          if override:
                              cls._logger.warning(f"{entity_id} already exists and will be overridden.")
                          else:
                              cls._logger.error(
                                  f"{entity_id} already exists. Please use the 'override' parameter to override it."
                              )
                              raise EntitiesToBeImportAlredyExist(zip_file_path)

                      # Import the entity
                      imported_entity = manager._import(
                          entity_file,
                          version=_VersionManagerFactory._build_manager()._get_latest_version(),
                          data_folder=tmp_dir_path / valid_data_folder,
                      )

                      imported_entities[entity_type].append(imported_entity.id)
                      if entity_type in ["scenario", "scenarios"]:
                          imported_scenario = imported_entity

              # An existing but empty scenario folder passes the early check above.
              # Fail here, inside the try block, so the rollback below also runs.
              if imported_scenario is None:
                  raise ImportArchiveDoesntContainAnyScenario(zip_file_path)
          except Exception as err:
              cls._logger.error(f"An error occurred during the import: {err}. Rollback the import.")

              # Rollback the import, undoing the most recently imported entity types first
              for entity_type, entity_ids in reversed(list(imported_entities.items())):
                  manager = entity_managers[entity_type]
                  for entity_id in entity_ids:
                      if manager._exists(entity_id):
                          manager._delete(entity_id)
              raise

      cls._logger.info(f"Scenario {imported_scenario.id} has been successfully imported.")
      return imported_scenario