_scenario_manager.py 22 KB


  1. # Copyright 2021-2025 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. from datetime import datetime
  12. from functools import partial
  13. from typing import Any, Callable, Dict, List, Literal, Optional, Set, Union
  14. from taipy.common.config import Config
  15. from .._entity._entity_ids import _EntityIds
  16. from .._manager._manager import _Manager
  17. from .._repository._abstract_repository import _AbstractRepository
  18. from .._version._version_mixin import _VersionMixin
  19. from ..common.warn_if_inputs_not_ready import _warn_if_inputs_not_ready
  20. from ..config.scenario_config import ScenarioConfig
  21. from ..cycle._cycle_manager_factory import _CycleManagerFactory
  22. from ..cycle.cycle import Cycle
  23. from ..data._data_manager_factory import _DataManagerFactory
  24. from ..exceptions.exceptions import (
  25. DeletingPrimaryScenario,
  26. DifferentScenarioConfigs,
  27. DoesNotBelongToACycle,
  28. InsufficientScenarioToCompare,
  29. InvalidScenario,
  30. NonExistingComparator,
  31. NonExistingScenario,
  32. NonExistingScenarioConfig,
  33. SequenceTaskConfigDoesNotExistInSameScenarioConfig,
  34. UnauthorizedTagError,
  35. )
  36. from ..job._job_manager_factory import _JobManagerFactory
  37. from ..job.job import Job
  38. from ..notification import EventEntityType, EventOperation, Notifier, _make_event
  39. from ..reason import (
  40. EntityDoesNotExist,
  41. EntityIsNotAScenario,
  42. EntityIsNotSubmittableEntity,
  43. ReasonCollection,
  44. ScenarioDoesNotBelongToACycle,
  45. ScenarioIsThePrimaryScenario,
  46. WrongConfigType,
  47. )
  48. from ..submission._submission_manager_factory import _SubmissionManagerFactory
  49. from ..submission.submission import Submission
  50. from ..task._task_manager_factory import _TaskManagerFactory
  51. from ._scenario_duplicator import _ScenarioDuplicator
  52. from .scenario import Scenario
  53. from .scenario_id import ScenarioId
  class _ScenarioManager(_Manager[Scenario], _VersionMixin):
      """Manager handling creation, retrieval, submission, tagging, comparison and deletion of scenarios.

      Every method is a class method or a static method operating on the scenario repository.
      """

      # Scenario property listing the tags a scenario may receive (see `_tag`).
      _AUTHORIZED_TAGS_KEY = "authorized_tags"
      _ENTITY_NAME = Scenario.__name__
      _EVENT_ENTITY_TYPE = EventEntityType.SCENARIO
      _repository: _AbstractRepository

      @classmethod
      def _get_all(cls, version_number: Optional[str] = None) -> List[Scenario]:
          """Return all scenarios matching the version filters built for `version_number`.

          Arguments:
              version_number (Optional[str]): Version used to build the repository filters.

          Returns:
              List[Scenario]: The scenarios loaded from the repository.
          """
          filters = cls._build_filters_with_version(version_number)
          return cls._repository._load_all(filters)

      @classmethod
      def _subscribe(
          cls,
          callback: Callable[[Scenario, Job], None],
          params: Optional[List[Any]] = None,
          scenario: Optional[Scenario] = None,
      ) -> None:
          """Register `callback` (called with `params`) as a subscriber.

          Arguments:
              callback: The function to register.
              params (Optional[List[Any]]): Extra parameters passed to `callback`.
              scenario (Optional[Scenario]): The scenario to subscribe to. If None, the callback is
                  registered on every scenario returned by `_get_all()`.
          """
          targets = cls._get_all() if scenario is None else [scenario]
          for target in targets:
              cls.__add_subscriber(callback, params, target)

      @classmethod
      def _unsubscribe(
          cls,
          callback: Callable[[Scenario, Job], None],
          params: Optional[List[Any]] = None,
          scenario: Optional[Scenario] = None,
      ) -> None:
          """Remove `callback` (registered with `params`) from the subscribers.

          Arguments:
              callback: The function to unregister.
              params (Optional[List[Any]]): The parameters the callback was registered with.
              scenario (Optional[Scenario]): The scenario to unsubscribe from. If None, the callback is
                  removed from every scenario returned by `_get_all()`.
          """
          targets = cls._get_all() if scenario is None else [scenario]
          for target in targets:
              cls.__remove_subscriber(callback, params, target)

      @classmethod
      def __add_subscriber(cls, callback, params, scenario: Scenario) -> None:
          """Add the subscriber to `scenario` and publish an UPDATE event on its "subscribers" attribute."""
          scenario._add_subscriber(callback, params)
          Notifier.publish(
              _make_event(scenario, EventOperation.UPDATE, attribute_name="subscribers", attribute_value=params)
          )

      @classmethod
      def __remove_subscriber(cls, callback, params, scenario: Scenario) -> None:
          """Remove the subscriber from `scenario` and publish an UPDATE event on its "subscribers" attribute."""
          scenario._remove_subscriber(callback, params)
          Notifier.publish(
              _make_event(scenario, EventOperation.UPDATE, attribute_name="subscribers", attribute_value=params)
          )

      @classmethod
      def _can_create(cls, config: Optional[ScenarioConfig] = None) -> ReasonCollection:
          """Check whether a scenario can be created from `config`.

          Returns:
              ReasonCollection: Contains a `WrongConfigType` reason when `config` is provided but is not
                  a `ScenarioConfig`; empty (truthy) otherwise.
          """
          config_id = getattr(config, "id", None) or str(config)
          reason_collector = ReasonCollection()
          if config is not None and not isinstance(config, ScenarioConfig):
              reason_collector._add_reason(config_id, WrongConfigType(config_id, ScenarioConfig.__name__))
          return reason_collector

      @staticmethod
      def __build_sequences(config: ScenarioConfig, tasks: List) -> Dict[str, Dict[str, Any]]:
          """Map each sequence name declared in `config` to the scenario tasks matching its task configs.

          Raises:
              SequenceTaskConfigDoesNotExistInSameScenarioConfig: If a sequence references task config ids
                  that match none of `tasks`.
          """
          tasks_by_config_id = {task.config_id: task for task in tasks}
          sequences = {}
          for sequence_name, sequence_task_configs in config.sequences.items():
              sequence_tasks = []
              missing_task_config_ids = set()
              for task_config in sequence_task_configs:
                  if task := tasks_by_config_id.get(task_config.id):
                      sequence_tasks.append(task)
                  else:
                      missing_task_config_ids.add(task_config.id)
              if missing_task_config_ids:
                  raise SequenceTaskConfigDoesNotExistInSameScenarioConfig(
                      list(missing_task_config_ids), sequence_name, str(config.id)
                  )
              sequences[sequence_name] = {Scenario._SEQUENCE_TASKS_KEY: sequence_tasks}
          return sequences

      @classmethod
      def _create(
          cls,
          config: ScenarioConfig,
          creation_date: Optional[datetime] = None,
          name: Optional[str] = None,
      ) -> Scenario:
          """Create, save and return a new scenario built from `config`.

          Tasks and additional data nodes are obtained through the bulk "get or create" of their managers.
          When the config has a frequency, the scenario is attached to the matching cycle. It becomes primary
          if no other scenario belongs to that cycle yet. Sequences declared in the config are created
          afterwards, and a CREATION event is published.

          Arguments:
              config (ScenarioConfig): The configuration of the scenario to create.
              creation_date (Optional[datetime]): Creation date, also used to get or create the cycle.
              name (Optional[str]): If provided, stored as the "name" property.

          Returns:
              Scenario: The newly created scenario.

          Raises:
              SequenceTaskConfigDoesNotExistInSameScenarioConfig: If a sequence references a task config that
                  is not part of the scenario config.
              InvalidScenario: If the created scenario is not consistent.
          """
          _task_manager = _TaskManagerFactory._build_manager()
          _data_manager = _DataManagerFactory._build_manager()
          scenario_id = Scenario._new_id(str(config.id))
          cycle = (
              _CycleManagerFactory._build_manager()._get_or_create(config.frequency, creation_date)
              if config.frequency
              else None
          )
          cycle_id = cycle.id if cycle else None
          tasks = (
              _task_manager._bulk_get_or_create(config.task_configs, cycle_id, scenario_id) if config.task_configs else []
          )
          additional_data_nodes = (
              _data_manager._bulk_get_or_create(config.additional_data_node_configs, cycle_id, scenario_id)
              if config.additional_data_node_configs
              else {}
          )
          sequences = cls.__build_sequences(config, tasks)
          # The first scenario created in a cycle becomes its primary scenario.
          is_primary_scenario = len(cls._get_all_by_cycle(cycle)) == 0 if cycle else False
          props = config._properties.copy()
          if name:
              props["name"] = name
          version = cls._get_latest_version()
          scenario = Scenario(
              config_id=str(config.id),
              tasks=set(tasks),
              properties=props,
              additional_data_nodes=set(additional_data_nodes.values()),
              scenario_id=scenario_id,
              creation_date=creation_date,
              is_primary=is_primary_scenario,
              cycle=cycle,
              version=version,
              sequences=sequences,
          )
          # Register the new scenario as a parent of its (possibly shared) tasks and data nodes.
          for task in tasks:
              if scenario_id not in task._parent_ids:
                  task._parent_ids.update([scenario_id])
                  _task_manager._update(task)
          for dn in additional_data_nodes.values():
              if scenario_id not in dn._parent_ids:
                  dn._parent_ids.update([scenario_id])
                  _data_manager._update(dn)
          cls._repository._save(scenario)
          if not scenario._is_consistent():
              raise InvalidScenario(scenario.id)
          # Local import: kept at function scope as in the original code.
          from ..sequence._sequence_manager_factory import _SequenceManagerFactory

          _SequenceManagerFactory._build_manager()._bulk_create_from_scenario(scenario)
          Notifier.publish(_make_event(scenario, EventOperation.CREATION))
          return scenario

      @classmethod
      def _is_submittable(cls, scenario: Union[Scenario, ScenarioId]) -> ReasonCollection:
          """Check whether `scenario` (entity or id) can be submitted.

          Returns:
              ReasonCollection: `EntityDoesNotExist` for an unknown id, `EntityIsNotSubmittableEntity` for a
                  non-scenario value, otherwise the result of `scenario.is_ready_to_run()`.
          """
          reason_collector = ReasonCollection()
          if isinstance(scenario, str):
              scenario_id = scenario
              scenario = cls._get(scenario)
              if scenario is None:
                  reason_collector._add_reason(scenario_id, EntityDoesNotExist(scenario_id))
                  return reason_collector
          if not isinstance(scenario, Scenario):
              reason_collector._add_reason(str(scenario), EntityIsNotSubmittableEntity(str(scenario)))
              return reason_collector
          return scenario.is_ready_to_run()

      @classmethod
      def _submit(
          cls,
          scenario: Union[Scenario, ScenarioId],
          callbacks: Optional[List[Callable]] = None,
          force: bool = False,
          wait: bool = False,
          timeout: Union[float, int, None] = None,
          check_inputs_are_ready: bool = True,
          **properties,
      ) -> Submission:
          """Submit `scenario` (entity or id) to the orchestrator and publish a SUBMISSION event.

          The scenario subscribers' callbacks are prepended to `callbacks`. When `check_inputs_are_ready` is
          True, `_warn_if_inputs_not_ready` is called on the scenario inputs first.

          Returns:
              Submission: The submission returned by the orchestrator.

          Raises:
              NonExistingScenario: If the scenario cannot be found or does not exist in the repository.
          """
          scenario_id = scenario.id if isinstance(scenario, Scenario) else scenario
          if not isinstance(scenario, Scenario):
              scenario = cls._get(scenario_id)
          if scenario is None or not cls._exists(scenario_id):
              raise NonExistingScenario(scenario_id)
          callbacks = callbacks or []
          scenario_subscription_callback = cls.__get_status_notifier_callbacks(scenario) + callbacks
          if check_inputs_are_ready:
              _warn_if_inputs_not_ready(scenario.get_inputs())
          submission = (
              _TaskManagerFactory._build_manager()
              ._orchestrator()
              .submit(
                  scenario,
                  callbacks=scenario_subscription_callback,
                  force=force,
                  wait=wait,
                  timeout=timeout,
                  **properties,
              )
          )
          Notifier.publish(_make_event(scenario, EventOperation.SUBMISSION))
          return submission

      @classmethod
      def __get_status_notifier_callbacks(cls, scenario: Scenario) -> List:
          """Bind each subscriber callback of `scenario` to its params followed by the scenario itself."""
          return [partial(c.callback, *c.params, scenario) for c in scenario.subscribers]

      @classmethod
      def _get_primary(cls, cycle: Cycle) -> Optional[Scenario]:
          """Return the primary scenario of `cycle`, or None if none of its scenarios is primary."""
          return next((scenario for scenario in cls._get_all_by_cycle(cycle) if scenario.is_primary), None)

      @classmethod
      def _get_all_by_cycle_tag(cls, cycle: Cycle, tag: str) -> List[Scenario]:
          """Return the scenarios of `cycle` that have `tag`."""
          cycles_scenarios = cls._get_all_by_cycle(cycle)
          return [scenario for scenario in cycles_scenarios if scenario.has_tag(tag)]

      @classmethod
      def _get_all_by_tag(cls, tag: str) -> List[Scenario]:
          """Return all scenarios (from `_get_all()`) that have `tag`."""
          return [scenario for scenario in cls._get_all() if scenario.has_tag(tag)]

      @classmethod
      def _get_all_by_cycle(cls, cycle: Cycle) -> List[Scenario]:
          """Return the scenarios of all versions whose "cycle" field equals `cycle.id`."""
          filters = cls._build_filters_with_version("all")
          if not filters:
              filters = [{}]
          for fil in filters:
              fil.update({"cycle": cycle.id})
          return cls._get_all_by(filters)

      @classmethod
      def _get_primary_scenarios(cls) -> List[Scenario]:
          """Return all primary scenarios (from `_get_all()`)."""
          return [scenario for scenario in cls._get_all() if scenario.is_primary]

      @staticmethod
      def _sort_scenarios(
          scenarios: List[Scenario],
          descending: bool = False,
          sort_key: Literal["name", "id", "config_id", "creation_date", "tags"] = "name",
      ) -> List[Scenario]:
          """Sort `scenarios` in place and return the same list.

          Ties are broken by scenario id, except when sorting by "id". Tags are compared as sorted tuples.
          An unknown `sort_key` falls back to sorting by name.
          """
          if sort_key in ["name", "config_id", "creation_date", "tags"]:
              if sort_key == "tags":
                  scenarios.sort(key=lambda x: (tuple(sorted(x.tags)), x.id), reverse=descending)
              else:
                  scenarios.sort(key=lambda x: (getattr(x, sort_key), x.id), reverse=descending)
          elif sort_key == "id":
              scenarios.sort(key=lambda x: x.id, reverse=descending)
          else:
              scenarios.sort(key=lambda x: (x.name, x.id), reverse=descending)
          return scenarios

      @staticmethod
      def _filter_by_creation_time(
          scenarios: List[Scenario],
          created_start_time: Optional[datetime] = None,
          created_end_time: Optional[datetime] = None,
      ) -> List[Scenario]:
          """
          Filter a list of scenarios by a given creation time period.

          Arguments:
              created_start_time (Optional[datetime]): Start time of the period. The start time is inclusive.
              created_end_time (Optional[datetime]): End time of the period. The end time is exclusive.

          Returns:
              List[Scenario]: List of scenarios created in the given time period.
          """
          if not created_start_time and not created_end_time:
              return scenarios
          if not created_start_time:
              return [scenario for scenario in scenarios if scenario.creation_date < created_end_time]
          if not created_end_time:
              return [scenario for scenario in scenarios if created_start_time <= scenario.creation_date]
          return [scenario for scenario in scenarios if created_start_time <= scenario.creation_date < created_end_time]

      @classmethod
      def _is_promotable_to_primary(cls, scenario: Union[Scenario, ScenarioId]) -> ReasonCollection:
          """Check whether `scenario` (entity or id) can become the primary scenario of its cycle.

          Returns:
              ReasonCollection: `EntityDoesNotExist` for an unknown scenario, `ScenarioDoesNotBelongToACycle`
                  when it has no cycle, `ScenarioIsThePrimaryScenario` when it is already primary.
          """
          reason_collection = ReasonCollection()
          if isinstance(scenario, str):
              scenario_id = scenario
              scenario = cls._get(scenario_id)
          else:
              scenario_id = scenario.id
          if not scenario:
              reason_collection._add_reason(scenario_id, EntityDoesNotExist(scenario_id))
          elif not scenario.cycle:
              # Checked first: the "already primary" reason below needs the cycle id.
              reason_collection._add_reason(scenario_id, ScenarioDoesNotBelongToACycle(scenario_id))
          elif scenario.is_primary:
              reason_collection._add_reason(scenario_id, ScenarioIsThePrimaryScenario(scenario_id, scenario.cycle.id))
          return reason_collection

      @classmethod
      def _set_primary(cls, scenario: Scenario) -> None:
          """Make `scenario` the primary scenario of its cycle, demoting the current primary one.

          Raises:
              DoesNotBelongToACycle: If `scenario` has no cycle.
          """
          if not scenario.cycle:
              raise DoesNotBelongToACycle(
                  f"Can't set scenario {scenario.id} to primary because it doesn't belong to a cycle."
              )
          primary_scenario = cls._get_primary(scenario.cycle)
          # To prevent SAME scenario updating out of Context Manager
          if primary_scenario and primary_scenario != scenario:
              primary_scenario.is_primary = False  # type: ignore
          scenario.is_primary = True  # type: ignore

      @classmethod
      def _tag(cls, scenario: Scenario, tag: str) -> None:
          """Add `tag` to `scenario`, persist it and publish an UPDATE event on "tags".

          Raises:
              UnauthorizedTagError: If the scenario defines a non-empty set of authorized tags that does
                  not contain `tag`.
          """
          tags = scenario.properties.get(cls._AUTHORIZED_TAGS_KEY, set())
          if len(tags) > 0 and tag not in tags:
              raise UnauthorizedTagError(f"Tag `{tag}` not authorized by scenario configuration `{scenario.config_id}`")
          scenario._add_tag(tag)
          cls._update(scenario)
          Notifier.publish(
              _make_event(scenario, EventOperation.UPDATE, attribute_name="tags", attribute_value=scenario.tags)
          )

      @classmethod
      def _untag(cls, scenario: Scenario, tag: str) -> None:
          """Remove `tag` from `scenario`, persist it and publish an UPDATE event on "tags"."""
          scenario._remove_tag(tag)
          cls._update(scenario)
          Notifier.publish(
              _make_event(scenario, EventOperation.UPDATE, attribute_name="tags", attribute_value=scenario.tags)
          )

      @classmethod
      def _compare(cls, *scenarios: Scenario, data_node_config_id: Optional[str] = None) -> Dict:
          """Run the comparators of the scenario config on the data nodes of `scenarios`.

          Arguments:
              *scenarios (Scenario): At least two scenarios sharing the same config id.
              data_node_config_id (Optional[str]): If provided, only the comparators of this data node config
                  are run; otherwise all comparators of the scenario config are run.

          Returns:
              Dict: `{data node config id: {comparator name: comparator result}}`.

          Raises:
              InsufficientScenarioToCompare: If fewer than two scenarios are given.
              DifferentScenarioConfigs: If the scenarios do not share the same config id.
              NonExistingComparator: If `data_node_config_id` has no comparator.
              NonExistingScenarioConfig: If the scenario config is not found in `Config.scenarios`.
          """
          if len(scenarios) < 2:
              raise InsufficientScenarioToCompare("At least two scenarios are required to compare.")
          if not all(scenarios[0].config_id == scenario.config_id for scenario in scenarios):
              raise DifferentScenarioConfigs("Scenarios to compare must have the same configuration.")
          scenario_config = cls.__get_config(scenarios[0])
          if not scenario_config:
              raise NonExistingScenarioConfig(scenarios[0].config_id)
          if data_node_config_id:
              if data_node_config_id not in scenario_config.comparators.keys():
                  raise NonExistingComparator(f"Data node config {data_node_config_id} has no comparator.")
              dn_comparators = {data_node_config_id: scenario_config.comparators[data_node_config_id]}
          else:
              dn_comparators = scenario_config.comparators
          results = {}
          # Distinct loop name so the `data_node_config_id` parameter is not shadowed.
          for dn_config_id, comparators in dn_comparators.items():
              data_nodes = [scenario.__getattr__(dn_config_id).read() for scenario in scenarios]
              results[dn_config_id] = {comparator.__name__: comparator(*data_nodes) for comparator in comparators}
          return results

      @staticmethod
      def __get_config(scenario: Scenario):
          """Return the `ScenarioConfig` registered under `scenario.config_id`, or None."""
          return Config.scenarios.get(scenario.config_id, None)

      @classmethod
      def _is_deletable(cls, scenario: Union[Scenario, ScenarioId]) -> ReasonCollection:
          """Check whether `scenario` (entity or id) can be deleted.

          Returns:
              ReasonCollection: `EntityDoesNotExist` for an unknown id, `EntityIsNotAScenario` for a
                  non-scenario value, `ScenarioIsThePrimaryScenario` when the scenario is primary and other
                  scenarios share its cycle; empty (truthy) otherwise.
          """
          reason_collection = ReasonCollection()
          if isinstance(scenario, str):
              scenario_id = scenario
              scenario = cls._get(scenario)
              if scenario is None:
                  reason_collection._add_reason(scenario_id, EntityDoesNotExist(scenario_id))
                  return reason_collection
          if not isinstance(scenario, Scenario):
              reason_collection._add_reason(str(scenario), EntityIsNotAScenario(str(scenario)))
          elif scenario.is_primary:
              if len(cls._get_all_by_cycle(scenario.cycle)) > 1:
                  reason_collection._add_reason(scenario.id, ScenarioIsThePrimaryScenario(scenario.id, scenario.cycle.id))
          return reason_collection

      @classmethod
      def __get_deletable_scenario(cls, scenario_id: ScenarioId) -> Scenario:
          """Fetch the scenario `scenario_id` and ensure it can be deleted.

          Raises:
              NonExistingScenario: If no scenario has the id `scenario_id`.
              DeletingPrimaryScenario: If `_is_deletable` rejects the scenario.
          """
          scenario = cls._get(scenario_id)
          # Without this check, the error message below would dereference None and raise AttributeError.
          if scenario is None:
              raise NonExistingScenario(scenario_id)
          if not cls._is_deletable(scenario):
              raise DeletingPrimaryScenario(
                  f"Scenario {scenario.id}, which has config id {scenario.config_id}, is primary and there are "
                  f"other scenarios in the same cycle. "
              )
          return scenario

      @classmethod
      def _delete(cls, scenario_id: ScenarioId) -> None:
          """Delete the scenario `scenario_id`; if it is primary, its cycle is deleted too.

          Raises:
              NonExistingScenario: If the scenario does not exist.
              DeletingPrimaryScenario: If the scenario is primary and other scenarios share its cycle.
          """
          scenario = cls.__get_deletable_scenario(scenario_id)
          if scenario.is_primary:
              _CycleManagerFactory._build_manager()._delete(scenario.cycle.id)
          super()._delete(scenario_id)

      @classmethod
      def _hard_delete(cls, scenario_id: ScenarioId) -> None:
          """Delete the scenario `scenario_id` together with its owned entities.

          A primary scenario is removed through the hard deletion of its cycle. Otherwise the scenario and
          the children returned by `_get_children_entity_ids` are deleted.

          Raises:
              NonExistingScenario: If the scenario does not exist.
              DeletingPrimaryScenario: If the scenario is primary and other scenarios share its cycle.
          """
          scenario = cls.__get_deletable_scenario(scenario_id)
          if scenario.is_primary:
              _CycleManagerFactory._build_manager()._hard_delete(scenario.cycle.id)
          else:
              entity_ids_to_delete = cls._get_children_entity_ids(scenario)
              entity_ids_to_delete.scenario_ids.add(scenario.id)
              cls._delete_entities_of_multiple_types(entity_ids_to_delete)

      @classmethod
      def _delete_by_version(cls, version_number: str) -> None:
          """
          Deletes scenario by the version number.

          Check if the cycle is only attached to this scenario, then delete it.
          """
          for scenario in cls._repository._search("version", version_number):
              if scenario.cycle and len(cls._get_all_by_cycle(scenario.cycle)) == 1:
                  _CycleManagerFactory._build_manager()._delete(scenario.cycle.id)
              super()._delete(scenario.id)

      @classmethod
      def _get_children_entity_ids(cls, scenario: Scenario) -> _EntityIds:
          """Collect the ids of the entities owned by `scenario`.

          Collects the sequences, tasks and data nodes whose `owner_id` is the scenario id. It also collects
          the jobs of those tasks, and the submissions whose `entity_id` is the scenario, one of those
          sequences or one of those tasks.
          """
          entity_ids = _EntityIds()
          for sequence in scenario.sequences.values():
              if sequence.owner_id == scenario.id:
                  entity_ids.sequence_ids.add(sequence.id)
          for task in scenario.tasks.values():
              if task.owner_id == scenario.id:
                  entity_ids.task_ids.add(task.id)
          for data_node in scenario.data_nodes.values():
              if data_node.owner_id == scenario.id:
                  entity_ids.data_node_ids.add(data_node.id)
          for job in _JobManagerFactory._build_manager()._get_all():
              if job.task.id in entity_ids.task_ids:
                  entity_ids.job_ids.add(job.id)
          # A set (built once, with the scenario id included) gives O(1) membership per submission.
          submitted_entity_ids = entity_ids.scenario_ids.union(entity_ids.sequence_ids, entity_ids.task_ids)
          submitted_entity_ids.add(scenario.id)
          for submission in _SubmissionManagerFactory._build_manager()._get_all():
              if submission.entity_id in submitted_entity_ids:
                  entity_ids.submission_ids.add(submission.id)
          return entity_ids

      @classmethod
      def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None) -> List[Scenario]:
          """
          Get all scenarios by its config id.
          """
          filters = cls._build_filters_with_version(version_number)
          if not filters:
              filters = [{}]
          for fil in filters:
              fil.update({"config_id": config_id})
          return cls._repository._load_all(filters)

      @classmethod
      def _duplicate(
          cls,
          scenario: Scenario,
          new_creation_date: Optional[datetime] = None,
          new_name: Optional[str] = None,
          data_to_duplicate: Union[bool, Set[str]] = True,
      ) -> Scenario:
          """Create a duplicated scenario with its related entities.

          Duplicate a scenario, publish a creation event and return the newly created
          scenario.

          Arguments:
              scenario (Scenario): The scenario to duplicate.
              new_creation_date (Optional[datetime]): The creation date of the new scenario.
                  If not provided, the current date and time is used.
              new_name (Optional[str]): The name of the new scenario. If not provided, the
                  name of the original scenario is used.
              data_to_duplicate (Union[Set[str], bool]): A set of data node configuration ids used
                  to duplicate only the data nodes' data with the specified configuration ids.
                  If True, all data nodes are duplicated. If False, no data nodes are duplicated.

          Returns:
              The newly created scenario.

          Raises:
              NonExistingScenario: If the scenario to duplicate does not exist in the repository.
          """
          if not cls._can_duplicate(scenario):
              # `_can_duplicate` only reports non-existence; raise the dedicated (Exception subclass) error.
              scenario_id = scenario.id if isinstance(scenario, Scenario) else str(scenario)
              raise NonExistingScenario(scenario_id)
          return _ScenarioDuplicator(scenario, data_to_duplicate).duplicate(new_creation_date, new_name)

      @classmethod
      def _can_duplicate(cls, scenario: Union[str, Scenario]) -> ReasonCollection:
          """Check whether `scenario` (entity or id) can be duplicated.

          Returns:
              ReasonCollection: `EntityDoesNotExist` when the id is not in the repository; empty otherwise.
          """
          reason_collector = ReasonCollection()
          if isinstance(scenario, Scenario):
              scenario_id = scenario.id
          else:
              scenario_id = str(scenario)  # type: ignore
          if not cls._repository._exists(scenario_id):
              reason_collector._add_reason(scenario_id, EntityDoesNotExist(scenario_id))
          return reason_collector