sequence.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # Copyright 2023 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. from __future__ import annotations
  12. from typing import Any, Callable, Dict, List, Optional, Set, Union
  13. import networkx as nx
  14. from taipy.config.common._template_handler import _TemplateHandler as _tpl
  15. from taipy.config.common._validate_id import _validate_id
  16. from .._entity._entity import _Entity
  17. from .._entity._labeled import _Labeled
  18. from .._entity._properties import _Properties
  19. from .._entity._reload import _Reloader, _self_reload, _self_setter
  20. from .._entity.submittable import Submittable
  21. from .._version._version_manager_factory import _VersionManagerFactory
  22. from ..common._listattributes import _ListAttributes
  23. from ..common._utils import _Subscriber
  24. from ..data.data_node import DataNode
  25. from ..exceptions.exceptions import NonExistingTask
  26. from ..job.job import Job
  27. from ..notification.event import Event, EventEntityType, EventOperation, _make_event
  28. from ..task.task import Task
  29. from ..task.task_id import TaskId
  30. from .sequence_id import SequenceId
  31. class Sequence(_Entity, Submittable, _Labeled):
  32. """List of `Task^`s and additional attributes representing a set of data processing
  33. elements connected as a direct acyclic graph.
  34. Attributes:
  35. properties (dict[str, Any]): A dictionary of additional properties.
  36. tasks (List[Task^]): The list of `Task`s.
  37. sequence_id (str): The Unique identifier of the sequence.
  38. owner_id (str): The identifier of the owner (scenario_id, cycle_id) or None.
  39. parent_ids (Optional[Set[str]]): The set of identifiers of the parent scenarios.
  40. version (str): The string indicates the application version of the sequence to instantiate. If not provided,
  41. the latest version is used.
  42. """
  43. _ID_PREFIX = "SEQUENCE"
  44. _SEPARATOR = "_"
  45. _MANAGER_NAME = "sequence"
  46. def __init__(
  47. self,
  48. properties: Dict[str, Any],
  49. tasks: Union[List[TaskId], List[Task], List[Union[TaskId, Task]]],
  50. sequence_id: SequenceId,
  51. owner_id: Optional[str] = None,
  52. parent_ids: Optional[Set[str]] = None,
  53. subscribers: Optional[List[_Subscriber]] = None,
  54. version: Optional[str] = None,
  55. ):
  56. super().__init__(subscribers)
  57. self.id: SequenceId = sequence_id
  58. self._tasks = tasks
  59. self._owner_id = owner_id
  60. self._parent_ids = parent_ids or set()
  61. self._properties = _Properties(self, **properties)
  62. self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
  63. @staticmethod
  64. def _new_id(sequence_name: str, scenario_id) -> SequenceId:
  65. return SequenceId(Sequence._SEPARATOR.join([Sequence._ID_PREFIX, _validate_id(sequence_name), scenario_id]))
  66. def __hash__(self):
  67. return hash(self.id)
  68. def __eq__(self, other):
  69. return self.id == other.id
  70. def __getattr__(self, attribute_name):
  71. protected_attribute_name = _validate_id(attribute_name)
  72. if protected_attribute_name in self._properties:
  73. return _tpl._replace_templates(self._properties[protected_attribute_name])
  74. tasks = self._get_tasks()
  75. if protected_attribute_name in tasks:
  76. return tasks[protected_attribute_name]
  77. for task in tasks.values():
  78. if protected_attribute_name in task.input:
  79. return task.input[protected_attribute_name]
  80. if protected_attribute_name in task.output:
  81. return task.output[protected_attribute_name]
  82. raise AttributeError(f"{attribute_name} is not an attribute of sequence {self.id}")
  83. @property # type: ignore
  84. @_self_reload(_MANAGER_NAME)
  85. def tasks(self) -> Dict[str, Task]:
  86. return self._get_tasks()
  87. @tasks.setter # type: ignore
  88. @_self_setter(_MANAGER_NAME)
  89. def tasks(self, tasks: Union[List[TaskId], List[Task]]):
  90. self._tasks = tasks
  91. @property
  92. def data_nodes(self) -> Dict[str, DataNode]:
  93. data_nodes = {}
  94. list_data_nodes = [task.data_nodes for task in self._get_tasks().values()]
  95. for data_node in list_data_nodes:
  96. for k, v in data_node.items():
  97. data_nodes[k] = v
  98. return data_nodes
  99. @property
  100. def parent_ids(self):
  101. return self._parent_ids
  102. @property
  103. def owner_id(self):
  104. return self._owner_id
  105. @property
  106. def version(self):
  107. return self._version
  108. @property
  109. def properties(self):
  110. self._properties = _Reloader()._reload("sequence", self)._properties
  111. return self._properties
  112. def _is_consistent(self) -> bool:
  113. dag = self._build_dag()
  114. if dag.number_of_nodes() == 0:
  115. return True
  116. if not nx.is_directed_acyclic_graph(dag):
  117. return False
  118. if not nx.is_weakly_connected(dag):
  119. return False
  120. for left_node, right_node in dag.edges:
  121. if (isinstance(left_node, DataNode) and isinstance(right_node, Task)) or (
  122. isinstance(left_node, Task) and isinstance(right_node, DataNode)
  123. ):
  124. continue
  125. return False
  126. return True
  127. def _get_tasks(self) -> Dict[str, Task]:
  128. from ..task._task_manager_factory import _TaskManagerFactory
  129. tasks = {}
  130. task_manager = _TaskManagerFactory._build_manager()
  131. for task_or_id in self._tasks:
  132. t = task_manager._get(task_or_id, task_or_id)
  133. if not isinstance(t, Task):
  134. raise NonExistingTask(task_or_id)
  135. tasks[t.config_id] = t
  136. return tasks
  137. def _get_set_of_tasks(self) -> Set[Task]:
  138. from ..task._task_manager_factory import _TaskManagerFactory
  139. tasks = set()
  140. task_manager = _TaskManagerFactory._build_manager()
  141. for task_or_id in self._tasks:
  142. task = task_manager._get(task_or_id, task_or_id)
  143. if not isinstance(task, Task):
  144. raise NonExistingTask(task_or_id)
  145. tasks.add(task)
  146. return tasks
  147. @property # type: ignore
  148. @_self_reload(_MANAGER_NAME)
  149. def subscribers(self):
  150. return self._subscribers
  151. @subscribers.setter # type: ignore
  152. @_self_setter(_MANAGER_NAME)
  153. def subscribers(self, val):
  154. self._subscribers = _ListAttributes(self, val)
  155. def get_parents(self):
  156. """Get parents of the sequence entity"""
  157. from ... import core as tp
  158. return tp.get_parents(self)
  159. def subscribe(
  160. self,
  161. callback: Callable[[Sequence, Job], None],
  162. params: Optional[List[Any]] = None,
  163. ):
  164. """Subscribe a function to be called on `Job^` status change.
  165. The subscription is applied to all jobs created from the sequence's execution.
  166. Parameters:
  167. callback (Callable[[Sequence^, Job^], None]): The callable function to be called on
  168. status change.
  169. params (Optional[List[Any]]): The parameters to be passed to the _callback_.
  170. Note:
  171. Notification will be available only for jobs created after this subscription.
  172. """
  173. from ... import core as tp
  174. return tp.subscribe_sequence(callback, params, self)
  175. def unsubscribe(self, callback: Callable[[Sequence, Job], None], params: Optional[List[Any]] = None):
  176. """Unsubscribe a function that is called when the status of a `Job^` changes.
  177. Parameters:
  178. callback (Callable[[Sequence^, Job^], None]): The callable function to unsubscribe.
  179. params (Optional[List[Any]]): The parameters to be passed to the _callback_.
  180. Note:
  181. The function will continue to be called for ongoing jobs.
  182. """
  183. from ... import core as tp
  184. return tp.unsubscribe_sequence(callback, params, self)
  185. def submit(
  186. self,
  187. callbacks: Optional[List[Callable]] = None,
  188. force: bool = False,
  189. wait: bool = False,
  190. timeout: Optional[Union[float, int]] = None,
  191. ) -> List[Job]:
  192. """Submit the sequence for execution.
  193. All the `Task^`s of the sequence will be submitted for execution.
  194. Parameters:
  195. callbacks (List[Callable]): The list of callable functions to be called on status
  196. change.
  197. force (bool): Force execution even if the data nodes are in cache.
  198. wait (bool): Wait for the orchestrated jobs created from the sequence submission to be finished
  199. in asynchronous mode.
  200. timeout (Union[float, int]): The maximum number of seconds to wait for the jobs to be finished before
  201. returning.
  202. Returns:
  203. A list of created `Job^`s.
  204. """
  205. from ._sequence_manager_factory import _SequenceManagerFactory
  206. return _SequenceManagerFactory._build_manager()._submit(self, callbacks, force, wait, timeout)
  207. def get_label(self) -> str:
  208. """Returns the sequence simple label prefixed by its owner label.
  209. Returns:
  210. The label of the sequence as a string.
  211. """
  212. return self._get_label()
  213. def get_simple_label(self) -> str:
  214. """Returns the sequence simple label.
  215. Returns:
  216. The simple label of the sequence as a string.
  217. """
  218. return self._get_simple_label()
  219. @_make_event.register(Sequence)
  220. def _make_event_for_sequence(
  221. sequence: Sequence,
  222. operation: EventOperation,
  223. /,
  224. attribute_name: Optional[str] = None,
  225. attribute_value: Optional[Any] = None,
  226. **kwargs,
  227. ) -> Event:
  228. metadata = {**kwargs}
  229. return Event(
  230. entity_type=EventEntityType.SEQUENCE,
  231. entity_id=sequence.id,
  232. operation=operation,
  233. attribute_name=attribute_name,
  234. attribute_value=attribute_value,
  235. metadata=metadata,
  236. )