_scenario_manager.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import datetime
  12. from functools import partial
  13. from typing import Any, Callable, List, Literal, Optional, Union
  14. from taipy.config import Config
  15. from .._entity._entity_ids import _EntityIds
  16. from .._manager._manager import _Manager
  17. from .._repository._abstract_repository import _AbstractRepository
  18. from .._version._version_mixin import _VersionMixin
  19. from ..common.warn_if_inputs_not_ready import _warn_if_inputs_not_ready
  20. from ..config.scenario_config import ScenarioConfig
  21. from ..cycle._cycle_manager_factory import _CycleManagerFactory
  22. from ..cycle.cycle import Cycle
  23. from ..data._data_manager_factory import _DataManagerFactory
  24. from ..exceptions.exceptions import (
  25. DeletingPrimaryScenario,
  26. DifferentScenarioConfigs,
  27. DoesNotBelongToACycle,
  28. InsufficientScenarioToCompare,
  29. InvalidSequence,
  30. InvalidSscenario,
  31. NonExistingComparator,
  32. NonExistingScenario,
  33. NonExistingScenarioConfig,
  34. SequenceTaskConfigDoesNotExistInSameScenarioConfig,
  35. UnauthorizedTagError,
  36. )
  37. from ..job._job_manager_factory import _JobManagerFactory
  38. from ..job.job import Job
  39. from ..notification import EventEntityType, EventOperation, Notifier, _make_event
  40. from ..submission._submission_manager_factory import _SubmissionManagerFactory
  41. from ..submission.submission import Submission
  42. from ..task._task_manager_factory import _TaskManagerFactory
  43. from .scenario import Scenario
  44. from .scenario_id import ScenarioId
class _ScenarioManager(_Manager[Scenario], _VersionMixin):
    """Manager for `Scenario` entities.

    Combines the generic `_Manager` (entity persistence through `_repository`) with `_VersionMixin`
    (version-based repository filters and latest-version lookup). Its class methods create scenarios from a
    `ScenarioConfig`, submit them to the orchestrator, manage subscribers, tags and the primary scenario of a
    cycle, compare scenarios, and delete scenarios together with the entities they own.
    """

    # Scenario property holding the tags a scenario may receive; `_tag` rejects any other tag when this
    # collection is non-empty.
    _AUTHORIZED_TAGS_KEY = "authorized_tags"
    # Entity name and event entity type of this manager; presumably consumed by the `_Manager` base class
    # (not visible here) — TODO confirm.
    _ENTITY_NAME = Scenario.__name__
    _EVENT_ENTITY_TYPE = EventEntityType.SCENARIO
    # Storage backend; only annotated here, presumably assigned elsewhere (e.g. by a manager factory) — TODO confirm.
    _repository: _AbstractRepository
  50. @classmethod
  51. def _get_all(cls, version_number: Optional[str] = None) -> List[Scenario]:
  52. """
  53. Returns all entities.
  54. """
  55. filters = cls._build_filters_with_version(version_number)
  56. return cls._repository._load_all(filters)
  57. @classmethod
  58. def _subscribe(
  59. cls,
  60. callback: Callable[[Scenario, Job], None],
  61. params: Optional[List[Any]] = None,
  62. scenario: Optional[Scenario] = None,
  63. ):
  64. if scenario is None:
  65. scenarios = cls._get_all()
  66. for scn in scenarios:
  67. cls.__add_subscriber(callback, params, scn)
  68. return
  69. cls.__add_subscriber(callback, params, scenario)
  70. @classmethod
  71. def _unsubscribe(
  72. cls,
  73. callback: Callable[[Scenario, Job], None],
  74. params: Optional[List[Any]] = None,
  75. scenario: Optional[Scenario] = None,
  76. ):
  77. if scenario is None:
  78. scenarios = cls._get_all()
  79. for scn in scenarios:
  80. cls.__remove_subscriber(callback, params, scn)
  81. return
  82. cls.__remove_subscriber(callback, params, scenario)
  83. @classmethod
  84. def __add_subscriber(cls, callback, params, scenario: Scenario):
  85. scenario._add_subscriber(callback, params)
  86. Notifier.publish(
  87. _make_event(scenario, EventOperation.UPDATE, attribute_name="subscribers", attribute_value=params)
  88. )
  89. @classmethod
  90. def __remove_subscriber(cls, callback, params, scenario: Scenario):
  91. scenario._remove_subscriber(callback, params)
  92. Notifier.publish(
  93. _make_event(scenario, EventOperation.UPDATE, attribute_name="subscribers", attribute_value=params)
  94. )
    @classmethod
    def _create(
        cls,
        config: ScenarioConfig,
        creation_date: Optional[datetime.datetime] = None,
        name: Optional[str] = None,
    ) -> Scenario:
        """Create, persist and announce a new scenario instantiated from `config`.

        Args:
            config: The scenario configuration to instantiate.
            creation_date: Passed to the cycle manager (only when `config.frequency` is set) and to `Scenario`.
            name: When truthy, stored in the scenario properties under the "name" key.

        Returns:
            The newly created scenario.

        Raises:
            SequenceTaskConfigDoesNotExistInSameScenarioConfig: If a sequence of `config` references a task config
                that did not produce one of the scenario's tasks.
            InvalidSscenario: If the scenario is not consistent once built.
            InvalidSequence: If one of the scenario's sequences is not consistent.
        """
        _task_manager = _TaskManagerFactory._build_manager()
        _data_manager = _DataManagerFactory._build_manager()
        scenario_id = Scenario._new_id(str(config.id))
        # A cycle is only fetched/created for configs declaring a frequency.
        cycle = (
            _CycleManagerFactory._build_manager()._get_or_create(config.frequency, creation_date)
            if config.frequency
            else None
        )
        cycle_id = cycle.id if cycle else None
        tasks = (
            _task_manager._bulk_get_or_create(config.task_configs, cycle_id, scenario_id) if config.task_configs else []
        )
        # Mapping whose values are the additional data nodes (only `.values()` is used below).
        additional_data_nodes = (
            _data_manager._bulk_get_or_create(config.additional_data_node_configs, cycle_id, scenario_id)
            if config.additional_data_node_configs
            else {}
        )
        # Resolve each configured sequence to the scenario's task entities, matched by task config id.
        sequences = {}
        tasks_and_config_id_maps = {task.config_id: task for task in tasks}
        for sequence_name, sequence_task_configs in config.sequences.items():
            sequence_tasks = []
            non_existing_sequence_task_config_in_scenario_config = set()
            for sequence_task_config in sequence_task_configs:
                if task := tasks_and_config_id_maps.get(sequence_task_config.id):
                    sequence_tasks.append(task)
                else:
                    non_existing_sequence_task_config_in_scenario_config.add(sequence_task_config.id)
            # Every unresolved task config id of this sequence is reported in a single exception.
            if non_existing_sequence_task_config_in_scenario_config:
                raise SequenceTaskConfigDoesNotExistInSameScenarioConfig(
                    list(non_existing_sequence_task_config_in_scenario_config), sequence_name, str(config.id)
                )
            sequences[sequence_name] = {Scenario._SEQUENCE_TASKS_KEY: sequence_tasks}
        # The first scenario of a cycle becomes its primary; scenarios without a cycle never are.
        is_primary_scenario = len(cls._get_all_by_cycle(cycle)) == 0 if cycle else False
        # Copy so adding the name does not mutate the config's own properties.
        props = config._properties.copy()
        if name:
            props["name"] = name
        version = cls._get_latest_version()
        scenario = Scenario(
            config_id=str(config.id),
            tasks=set(tasks),
            properties=props,
            additional_data_nodes=set(additional_data_nodes.values()),
            scenario_id=scenario_id,
            creation_date=creation_date,
            is_primary=is_primary_scenario,
            cycle=cycle,
            version=version,
            sequences=sequences,
        )
        # Register the new scenario id as a parent of its tasks and additional data nodes.
        # NOTE(review): the indentation of this source was lost; the `_set` calls are reconstructed as running only
        # when the parent ids changed (upstream layout) — confirm against the original file.
        for task in tasks:
            if scenario_id not in task._parent_ids:
                task._parent_ids.update([scenario_id])
                _task_manager._set(task)
        for dn in additional_data_nodes.values():
            if scenario_id not in dn._parent_ids:
                dn._parent_ids.update([scenario_id])
                _data_manager._set(dn)
        # `cls._set(scenario)` has already run when the consistency checks below raise.
        cls._set(scenario)
        if not scenario._is_consistent():
            raise InvalidSscenario(scenario.id)
        actual_sequences = scenario._get_sequences()
        for sequence_name in sequences.keys():
            if not actual_sequences[sequence_name]._is_consistent():
                raise InvalidSequence(actual_sequences[sequence_name].id)
            Notifier.publish(_make_event(actual_sequences[sequence_name], EventOperation.CREATION))
        # The scenario CREATION event is published after those of its sequences.
        Notifier.publish(_make_event(scenario, EventOperation.CREATION))
        return scenario
  169. @classmethod
  170. def _is_submittable(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
  171. if isinstance(scenario, str):
  172. scenario = cls._get(scenario)
  173. return isinstance(scenario, Scenario) and scenario.is_ready_to_run()
  174. @classmethod
  175. def _submit(
  176. cls,
  177. scenario: Union[Scenario, ScenarioId],
  178. callbacks: Optional[List[Callable]] = None,
  179. force: bool = False,
  180. wait: bool = False,
  181. timeout: Optional[Union[float, int]] = None,
  182. check_inputs_are_ready: bool = True,
  183. **properties,
  184. ) -> Submission:
  185. scenario_id = scenario.id if isinstance(scenario, Scenario) else scenario
  186. scenario = cls._get(scenario_id)
  187. if scenario is None:
  188. raise NonExistingScenario(scenario_id)
  189. callbacks = callbacks or []
  190. scenario_subscription_callback = cls.__get_status_notifier_callbacks(scenario) + callbacks
  191. if check_inputs_are_ready:
  192. _warn_if_inputs_not_ready(scenario.get_inputs())
  193. submission = (
  194. _TaskManagerFactory._build_manager()
  195. ._orchestrator()
  196. .submit(
  197. scenario,
  198. callbacks=scenario_subscription_callback,
  199. force=force,
  200. wait=wait,
  201. timeout=timeout,
  202. **properties,
  203. )
  204. )
  205. Notifier.publish(_make_event(scenario, EventOperation.SUBMISSION))
  206. return submission
  207. @classmethod
  208. def __get_status_notifier_callbacks(cls, scenario: Scenario) -> List:
  209. return [partial(c.callback, *c.params, scenario) for c in scenario.subscribers]
  210. @classmethod
  211. def _get_primary(cls, cycle: Cycle) -> Optional[Scenario]:
  212. scenarios = cls._get_all_by_cycle(cycle)
  213. for scenario in scenarios:
  214. if scenario.is_primary:
  215. return scenario
  216. return None
  217. @classmethod
  218. def _get_by_tag(cls, cycle: Cycle, tag: str) -> Optional[Scenario]:
  219. scenarios = cls._get_all_by_cycle(cycle)
  220. for scenario in scenarios:
  221. if scenario.has_tag(tag):
  222. return scenario
  223. return None
  224. @classmethod
  225. def _get_all_by_tag(cls, tag: str) -> List[Scenario]:
  226. return [scenario for scenario in cls._get_all() if scenario.has_tag(tag)]
  227. @classmethod
  228. def _get_all_by_cycle(cls, cycle: Cycle) -> List[Scenario]:
  229. filters = cls._build_filters_with_version("all")
  230. if not filters:
  231. filters = [{}]
  232. for fil in filters:
  233. fil.update({"cycle": cycle.id})
  234. return cls._get_all_by(filters)
  235. @classmethod
  236. def _get_primary_scenarios(cls) -> List[Scenario]:
  237. return [scenario for scenario in cls._get_all() if scenario.is_primary]
  238. @classmethod
  239. def _sort_scenarios(
  240. cls,
  241. scenarios: List[Scenario],
  242. descending: bool = False,
  243. sort_key: Literal["name", "id", "config_id", "creation_date", "tags"] = "name",
  244. ) -> List[Scenario]:
  245. if sort_key in ["name", "config_id", "creation_date", "tags"]:
  246. if sort_key == "tags":
  247. scenarios.sort(key=lambda x: (tuple(sorted(x.tags)), x.id), reverse=descending)
  248. else:
  249. scenarios.sort(key=lambda x: (getattr(x, sort_key), x.id), reverse=descending)
  250. elif sort_key == "id":
  251. scenarios.sort(key=lambda x: x.id, reverse=descending)
  252. else:
  253. scenarios.sort(key=lambda x: (x.name, x.id), reverse=descending)
  254. return scenarios
  255. @classmethod
  256. def _is_promotable_to_primary(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
  257. if isinstance(scenario, str):
  258. scenario = cls._get(scenario)
  259. if scenario and not scenario.is_primary and scenario.cycle:
  260. return True
  261. return False
  262. @classmethod
  263. def _set_primary(cls, scenario: Scenario):
  264. if not scenario.cycle:
  265. raise DoesNotBelongToACycle(
  266. f"Can't set scenario {scenario.id} to primary because it doesn't belong to a cycle."
  267. )
  268. primary_scenario = cls._get_primary(scenario.cycle)
  269. # To prevent SAME scenario updating out of Context Manager
  270. if primary_scenario and primary_scenario != scenario:
  271. primary_scenario.is_primary = False # type: ignore
  272. scenario.is_primary = True # type: ignore
  273. @classmethod
  274. def _tag(cls, scenario: Scenario, tag: str):
  275. tags = scenario.properties.get(cls._AUTHORIZED_TAGS_KEY, set())
  276. if len(tags) > 0 and tag not in tags:
  277. raise UnauthorizedTagError(f"Tag `{tag}` not authorized by scenario configuration `{scenario.config_id}`")
  278. if scenario.cycle:
  279. if old_tagged_scenario := cls._get_by_tag(scenario.cycle, tag):
  280. old_tagged_scenario.remove_tag(tag)
  281. cls._set(old_tagged_scenario)
  282. scenario._add_tag(tag)
  283. cls._set(scenario)
  284. Notifier.publish(
  285. _make_event(scenario, EventOperation.UPDATE, attribute_name="tags", attribute_value=scenario.tags)
  286. )
  287. @classmethod
  288. def _untag(cls, scenario: Scenario, tag: str):
  289. scenario._remove_tag(tag)
  290. cls._set(scenario)
  291. Notifier.publish(
  292. _make_event(scenario, EventOperation.UPDATE, attribute_name="tags", attribute_value=scenario.tags)
  293. )
  294. @classmethod
  295. def _compare(cls, *scenarios: Scenario, data_node_config_id: Optional[str] = None):
  296. if len(scenarios) < 2:
  297. raise InsufficientScenarioToCompare("At least two scenarios are required to compare.")
  298. if not all(scenarios[0].config_id == scenario.config_id for scenario in scenarios):
  299. raise DifferentScenarioConfigs("Scenarios to compare must have the same configuration.")
  300. if scenario_config := _ScenarioManager.__get_config(scenarios[0]):
  301. results = {}
  302. if data_node_config_id:
  303. if data_node_config_id in scenario_config.comparators.keys():
  304. dn_comparators = {data_node_config_id: scenario_config.comparators[data_node_config_id]}
  305. else:
  306. raise NonExistingComparator(f"Data node config {data_node_config_id} has no comparator.")
  307. else:
  308. dn_comparators = scenario_config.comparators
  309. for data_node_config_id, comparators in dn_comparators.items():
  310. data_nodes = [scenario.__getattr__(data_node_config_id).read() for scenario in scenarios]
  311. results[data_node_config_id] = {
  312. comparator.__name__: comparator(*data_nodes) for comparator in comparators
  313. }
  314. return results
  315. else:
  316. raise NonExistingScenarioConfig(scenarios[0].config_id)
  317. @staticmethod
  318. def __get_config(scenario: Scenario):
  319. return Config.scenarios.get(scenario.config_id, None)
  320. @classmethod
  321. def _is_deletable(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
  322. if isinstance(scenario, str):
  323. scenario = cls._get(scenario)
  324. if scenario.is_primary:
  325. if len(cls._get_all_by_cycle(scenario.cycle)) > 1:
  326. return False
  327. return True
  328. @classmethod
  329. def _delete(cls, scenario_id: ScenarioId):
  330. scenario = cls._get(scenario_id)
  331. if not cls._is_deletable(scenario):
  332. raise DeletingPrimaryScenario(
  333. f"Scenario {scenario.id}, which has config id {scenario.config_id}, is primary and there are "
  334. f"other scenarios in the same cycle. "
  335. )
  336. if scenario.is_primary:
  337. _CycleManagerFactory._build_manager()._delete(scenario.cycle.id)
  338. super()._delete(scenario_id)
  339. @classmethod
  340. def _hard_delete(cls, scenario_id: ScenarioId):
  341. scenario = cls._get(scenario_id)
  342. if not cls._is_deletable(scenario):
  343. raise DeletingPrimaryScenario(
  344. f"Scenario {scenario.id}, which has config id {scenario.config_id}, is primary and there are "
  345. f"other scenarios in the same cycle. "
  346. )
  347. if scenario.is_primary:
  348. _CycleManagerFactory._build_manager()._hard_delete(scenario.cycle.id)
  349. else:
  350. entity_ids_to_delete = cls._get_children_entity_ids(scenario)
  351. entity_ids_to_delete.scenario_ids.add(scenario.id)
  352. cls._delete_entities_of_multiple_types(entity_ids_to_delete)
  353. @classmethod
  354. def _delete_by_version(cls, version_number: str):
  355. """
  356. Deletes scenario by the version number.
  357. Check if the cycle is only attached to this scenario, then delete it.
  358. """
  359. for scenario in cls._repository._search("version", version_number):
  360. if scenario.cycle and len(cls._get_all_by_cycle(scenario.cycle)) == 1:
  361. _CycleManagerFactory._build_manager()._delete(scenario.cycle.id)
  362. super()._delete(scenario.id)
  363. @classmethod
  364. def _get_children_entity_ids(cls, scenario: Scenario) -> _EntityIds:
  365. entity_ids = _EntityIds()
  366. for sequence in scenario.sequences.values():
  367. if sequence.owner_id == scenario.id:
  368. entity_ids.sequence_ids.add(sequence.id)
  369. for task in scenario.tasks.values():
  370. if task.owner_id == scenario.id:
  371. entity_ids.task_ids.add(task.id)
  372. for data_node in scenario.data_nodes.values():
  373. if data_node.owner_id == scenario.id:
  374. entity_ids.data_node_ids.add(data_node.id)
  375. jobs = _JobManagerFactory._build_manager()._get_all()
  376. for job in jobs:
  377. if job.task.id in entity_ids.task_ids:
  378. entity_ids.job_ids.add(job.id)
  379. submissions = _SubmissionManagerFactory._build_manager()._get_all()
  380. submitted_entity_ids = list(entity_ids.scenario_ids.union(entity_ids.sequence_ids, entity_ids.task_ids))
  381. for submission in submissions:
  382. if submission.entity_id in submitted_entity_ids or submission.entity_id == scenario.id:
  383. entity_ids.submission_ids.add(submission.id)
  384. return entity_ids
  385. @classmethod
  386. def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None) -> List[Scenario]:
  387. """
  388. Get all scenarios by its config id.
  389. """
  390. filters = cls._build_filters_with_version(version_number)
  391. if not filters:
  392. filters = [{}]
  393. for fil in filters:
  394. fil.update({"config_id": config_id})
  395. return cls._repository._load_all(filters)