sequence.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. from __future__ import annotations
  12. from typing import Any, Callable, Dict, List, Optional, Set, Union
  13. import networkx as nx
  14. from taipy.config.common._template_handler import _TemplateHandler as _tpl
  15. from taipy.config.common._validate_id import _validate_id
  16. from .._entity._entity import _Entity
  17. from .._entity._labeled import _Labeled
  18. from .._entity._properties import _Properties
  19. from .._entity._reload import _Reloader, _self_reload, _self_setter
  20. from .._entity.submittable import Submittable
  21. from .._version._version_manager_factory import _VersionManagerFactory
  22. from ..common._listattributes import _ListAttributes
  23. from ..common._utils import _Subscriber
  24. from ..data.data_node import DataNode
  25. from ..exceptions.exceptions import NonExistingTask
  26. from ..job.job import Job
  27. from ..notification.event import Event, EventEntityType, EventOperation, _make_event
  28. from ..submission.submission import Submission
  29. from ..task.task import Task
  30. from ..task.task_id import TaskId
  31. from .sequence_id import SequenceId
  32. class Sequence(_Entity, Submittable, _Labeled):
  33. """List of `Task^`s and additional attributes representing a set of data processing
  34. elements connected as a direct acyclic graph.
  35. Attributes:
  36. properties (dict[str, Any]): A dictionary of additional properties.
  37. tasks (List[Task^]): The list of `Task`s.
  38. sequence_id (str): The Unique identifier of the sequence.
  39. owner_id (str): The identifier of the owner (scenario_id, cycle_id) or None.
  40. parent_ids (Optional[Set[str]]): The set of identifiers of the parent scenarios.
  41. version (str): The string indicates the application version of the sequence to instantiate. If not provided,
  42. the latest version is used.
  43. """
  44. _ID_PREFIX = "SEQUENCE"
  45. _SEPARATOR = "_"
  46. _MANAGER_NAME = "sequence"
  47. def __init__(
  48. self,
  49. properties: Dict[str, Any],
  50. tasks: Union[List[TaskId], List[Task], List[Union[TaskId, Task]]],
  51. sequence_id: SequenceId,
  52. owner_id: Optional[str] = None,
  53. parent_ids: Optional[Set[str]] = None,
  54. subscribers: Optional[List[_Subscriber]] = None,
  55. version: Optional[str] = None,
  56. ):
  57. super().__init__(subscribers)
  58. self.id: SequenceId = sequence_id
  59. self._tasks = tasks
  60. self._owner_id = owner_id
  61. self._parent_ids = parent_ids or set()
  62. self._properties = _Properties(self, **properties)
  63. self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
  64. @staticmethod
  65. def _new_id(sequence_name: str, scenario_id) -> SequenceId:
  66. seq_id = sequence_name.replace(" ", "TPSPACE")
  67. return SequenceId(Sequence._SEPARATOR.join([Sequence._ID_PREFIX, _validate_id(seq_id), scenario_id]))
  68. def __hash__(self):
  69. return hash(self.id)
  70. def __eq__(self, other):
  71. return self.id == other.id
  72. def __getattr__(self, attribute_name):
  73. protected_attribute_name = _validate_id(attribute_name)
  74. if protected_attribute_name in self._properties:
  75. return _tpl._replace_templates(self._properties[protected_attribute_name])
  76. tasks = self._get_tasks()
  77. if protected_attribute_name in tasks:
  78. return tasks[protected_attribute_name]
  79. for task in tasks.values():
  80. if protected_attribute_name in task.input:
  81. return task.input[protected_attribute_name]
  82. if protected_attribute_name in task.output:
  83. return task.output[protected_attribute_name]
  84. raise AttributeError(f"{attribute_name} is not an attribute of sequence {self.id}")
  85. @property # type: ignore
  86. @_self_reload(_MANAGER_NAME)
  87. def tasks(self) -> Dict[str, Task]:
  88. return self._get_tasks()
  89. @tasks.setter # type: ignore
  90. @_self_setter(_MANAGER_NAME)
  91. def tasks(self, tasks: Union[List[TaskId], List[Task]]):
  92. self._tasks = tasks
  93. @property
  94. def data_nodes(self) -> Dict[str, DataNode]:
  95. data_nodes = {}
  96. list_data_nodes = [task.data_nodes for task in self._get_tasks().values()]
  97. for data_node in list_data_nodes:
  98. for k, v in data_node.items():
  99. data_nodes[k] = v
  100. return data_nodes
  101. @property
  102. def parent_ids(self):
  103. return self._parent_ids
  104. @property
  105. def owner_id(self):
  106. return self._owner_id
  107. @property
  108. def version(self):
  109. return self._version
  110. @property
  111. def properties(self):
  112. self._properties = _Reloader()._reload("sequence", self)._properties
  113. return self._properties
  114. def _is_consistent(self) -> bool:
  115. dag = self._build_dag()
  116. if dag.number_of_nodes() == 0:
  117. return True
  118. if not nx.is_directed_acyclic_graph(dag):
  119. return False
  120. if not nx.is_weakly_connected(dag):
  121. return False
  122. for left_node, right_node in dag.edges:
  123. if (isinstance(left_node, DataNode) and isinstance(right_node, Task)) or (
  124. isinstance(left_node, Task) and isinstance(right_node, DataNode)
  125. ):
  126. continue
  127. return False
  128. return True
  129. def _get_tasks(self) -> Dict[str, Task]:
  130. from ..task._task_manager_factory import _TaskManagerFactory
  131. tasks = {}
  132. task_manager = _TaskManagerFactory._build_manager()
  133. for task_or_id in self._tasks:
  134. t = task_manager._get(task_or_id, task_or_id)
  135. if not isinstance(t, Task):
  136. raise NonExistingTask(task_or_id)
  137. tasks[t.config_id] = t
  138. return tasks
  139. def _get_set_of_tasks(self) -> Set[Task]:
  140. from ..task._task_manager_factory import _TaskManagerFactory
  141. tasks = set()
  142. task_manager = _TaskManagerFactory._build_manager()
  143. for task_or_id in self._tasks:
  144. task = task_manager._get(task_or_id, task_or_id)
  145. if not isinstance(task, Task):
  146. raise NonExistingTask(task_or_id)
  147. tasks.add(task)
  148. return tasks
  149. @property # type: ignore
  150. @_self_reload(_MANAGER_NAME)
  151. def subscribers(self):
  152. return self._subscribers
  153. @subscribers.setter # type: ignore
  154. @_self_setter(_MANAGER_NAME)
  155. def subscribers(self, val):
  156. self._subscribers = _ListAttributes(self, val)
  157. def get_parents(self):
  158. """Get parents of the sequence entity"""
  159. from ... import core as tp
  160. return tp.get_parents(self)
  161. def subscribe(
  162. self,
  163. callback: Callable[[Sequence, Job], None],
  164. params: Optional[List[Any]] = None,
  165. ):
  166. """Subscribe a function to be called on `Job^` status change.
  167. The subscription is applied to all jobs created from the sequence's execution.
  168. Parameters:
  169. callback (Callable[[Sequence^, Job^], None]): The callable function to be called on
  170. status change.
  171. params (Optional[List[Any]]): The parameters to be passed to the _callback_.
  172. Note:
  173. Notification will be available only for jobs created after this subscription.
  174. """
  175. from ... import core as tp
  176. return tp.subscribe_sequence(callback, params, self)
  177. def unsubscribe(self, callback: Callable[[Sequence, Job], None], params: Optional[List[Any]] = None):
  178. """Unsubscribe a function that is called when the status of a `Job^` changes.
  179. Parameters:
  180. callback (Callable[[Sequence^, Job^], None]): The callable function to unsubscribe.
  181. params (Optional[List[Any]]): The parameters to be passed to the _callback_.
  182. Note:
  183. The function will continue to be called for ongoing jobs.
  184. """
  185. from ... import core as tp
  186. return tp.unsubscribe_sequence(callback, params, self)
  187. def submit(
  188. self,
  189. callbacks: Optional[List[Callable]] = None,
  190. force: bool = False,
  191. wait: bool = False,
  192. timeout: Optional[Union[float, int]] = None,
  193. **properties,
  194. ) -> Submission:
  195. """Submit the sequence for execution.
  196. All the `Task^`s of the sequence will be submitted for execution.
  197. Parameters:
  198. callbacks (List[Callable]): The list of callable functions to be called on status
  199. change.
  200. force (bool): Force execution even if the data nodes are in cache.
  201. wait (bool): Wait for the orchestrated jobs created from the sequence submission to be finished
  202. in asynchronous mode.
  203. timeout (Union[float, int]): The maximum number of seconds to wait for the jobs to be finished before
  204. returning.
  205. **properties (dict[str, any]): A keyworded variable length list of additional arguments.
  206. Returns:
  207. A `Submission^` containing the information of the submission.
  208. """
  209. from ._sequence_manager_factory import _SequenceManagerFactory
  210. return _SequenceManagerFactory._build_manager()._submit(self, callbacks, force, wait, timeout, **properties)
  211. def get_label(self) -> str:
  212. """Returns the sequence simple label prefixed by its owner label.
  213. Returns:
  214. The label of the sequence as a string.
  215. """
  216. return self._get_label()
  217. def get_simple_label(self) -> str:
  218. """Returns the sequence simple label.
  219. Returns:
  220. The simple label of the sequence as a string.
  221. """
  222. return self._get_simple_label()
  223. @_make_event.register(Sequence)
  224. def _make_event_for_sequence(
  225. sequence: Sequence,
  226. operation: EventOperation,
  227. /,
  228. attribute_name: Optional[str] = None,
  229. attribute_value: Optional[Any] = None,
  230. **kwargs,
  231. ) -> Event:
  232. metadata = {**kwargs}
  233. return Event(
  234. entity_type=EventEntityType.SEQUENCE,
  235. entity_id=sequence.id,
  236. operation=operation,
  237. attribute_name=attribute_name,
  238. attribute_value=attribute_value,
  239. metadata=metadata,
  240. )