task.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import uuid
  12. from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union
  13. from taipy.config.common._template_handler import _TemplateHandler as _tpl
  14. from taipy.config.common._validate_id import _validate_id
  15. from taipy.config.common.scope import Scope
  16. from .._entity._entity import _Entity
  17. from .._entity._labeled import _Labeled
  18. from .._entity._properties import _Properties
  19. from .._entity._reload import _Reloader, _self_reload, _self_setter
  20. from .._version._version_manager_factory import _VersionManagerFactory
  21. from ..data.data_node import DataNode
  22. from ..notification.event import Event, EventEntityType, EventOperation, _make_event
  23. from ..submission.submission import Submission
  24. from .task_id import TaskId
  25. class Task(_Entity, _Labeled):
  26. """Hold a user function that will be executed, its parameters and the results.
  27. A `Task` brings together the user code as function, the inputs and the outputs as data nodes
  28. (instances of the `DataNode^` class).
  29. Attributes:
  30. config_id (str): The identifier of the `TaskConfig^`.
  31. properties (dict[str, Any]): A dictionary of additional properties.
  32. function (callable): The python function to execute. The _function_ must take as parameter the
  33. data referenced by inputs data nodes, and must return the data referenced by outputs data nodes.
  34. input (Union[DataNode^, List[DataNode^]]): The list of inputs.
  35. output (Union[DataNode^, List[DataNode^]]): The list of outputs.
  36. id (str): The unique identifier of the task.
  37. owner_id (str): The identifier of the owner (sequence_id, scenario_id, cycle_id) or None.
  38. parent_ids (Optional[Set[str]]): The set of identifiers of the parent sequences.
  39. version (str): The string indicates the application version of the task to instantiate. If not provided, the
  40. latest version is used.
  41. skippable (bool): If True, indicates that the task can be skipped if no change has been made on inputs. The
  42. default value is _False_.
  43. """
  44. _ID_PREFIX = "TASK"
  45. __ID_SEPARATOR = "_"
  46. _MANAGER_NAME = "task"
  47. def __init__(
  48. self,
  49. config_id: str,
  50. properties: Dict[str, Any],
  51. function,
  52. input: Optional[Iterable[DataNode]] = None,
  53. output: Optional[Iterable[DataNode]] = None,
  54. id: Optional[TaskId] = None,
  55. owner_id: Optional[str] = None,
  56. parent_ids: Optional[Set[str]] = None,
  57. version: Optional[str] = None,
  58. skippable: bool = False,
  59. ):
  60. self._config_id = _validate_id(config_id)
  61. self.id = id or TaskId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())]))
  62. self._owner_id = owner_id
  63. self._parent_ids = parent_ids or set()
  64. self.__input = {dn.config_id: dn for dn in input or []}
  65. self.__output = {dn.config_id: dn for dn in output or []}
  66. self._function = function
  67. self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
  68. self._skippable = skippable
  69. self._properties = _Properties(self, **properties)
  70. def __hash__(self):
  71. return hash(self.id)
  72. def __eq__(self, other):
  73. return self.id == other.id
  74. def __getstate__(self):
  75. return vars(self)
  76. def __setstate__(self, state):
  77. vars(self).update(state)
  78. def __getattr__(self, attribute_name):
  79. protected_attribute_name = _validate_id(attribute_name)
  80. if protected_attribute_name in self._properties:
  81. return _tpl._replace_templates(self._properties[protected_attribute_name])
  82. if protected_attribute_name in self.input:
  83. return self.input[protected_attribute_name]
  84. if protected_attribute_name in self.output:
  85. return self.output[protected_attribute_name]
  86. raise AttributeError(f"{attribute_name} is not an attribute of task {self.id}")
  87. @property
  88. def properties(self):
  89. self._properties = _Reloader()._reload(self._MANAGER_NAME, self)._properties
  90. return self._properties
  91. @property
  92. def config_id(self):
  93. return self._config_id
  94. @property
  95. def owner_id(self):
  96. return self._owner_id
  97. def get_parents(self):
  98. """Get parents of the task."""
  99. from ... import core as tp
  100. return tp.get_parents(self)
  101. @property # type: ignore
  102. @_self_reload(_MANAGER_NAME)
  103. def parent_ids(self):
  104. return self._parent_ids
  105. @property
  106. def input(self) -> Dict[str, DataNode]:
  107. return self.__input
  108. @property
  109. def output(self) -> Dict[str, DataNode]:
  110. return self.__output
  111. @property
  112. def data_nodes(self) -> Dict[str, DataNode]:
  113. return {**self.input, **self.output}
  114. @property # type: ignore
  115. @_self_reload(_MANAGER_NAME)
  116. def function(self):
  117. return self._function
  118. @function.setter # type: ignore
  119. @_self_setter(_MANAGER_NAME)
  120. def function(self, val):
  121. self._function = val
  122. @property # type: ignore
  123. @_self_reload(_MANAGER_NAME)
  124. def skippable(self):
  125. return self._skippable
  126. @skippable.setter # type: ignore
  127. @_self_setter(_MANAGER_NAME)
  128. def skippable(self, val):
  129. self._skippable = val
  130. @property
  131. def scope(self) -> Scope:
  132. """Retrieve the lowest scope of the task based on its data nodes.
  133. Returns:
  134. The lowest scope present in input and output data nodes or GLOBAL if there are
  135. either no input or no output.
  136. """
  137. data_nodes = list(self.__input.values()) + list(self.__output.values())
  138. return Scope(min(dn.scope for dn in data_nodes)) if len(data_nodes) != 0 else Scope.GLOBAL
  139. @property
  140. def version(self):
  141. return self._version
  142. def submit(
  143. self,
  144. callbacks: Optional[List[Callable]] = None,
  145. force: bool = False,
  146. wait: bool = False,
  147. timeout: Optional[Union[float, int]] = None,
  148. **properties,
  149. ) -> Submission:
  150. """Submit the task for execution.
  151. Parameters:
  152. callbacks (List[Callable]): The list of callable functions to be called on status
  153. change.
  154. force (bool): Force execution even if the data nodes are in cache.
  155. wait (bool): Wait for the orchestrated job created from the task submission to be finished in asynchronous
  156. mode.
  157. timeout (Union[float, int]): The maximum number of seconds to wait for the job to be finished before
  158. returning.
  159. **properties (dict[str, any]): A keyworded variable length list of additional arguments.
  160. Returns:
  161. A `Submission^` containing the information of the submission.
  162. """
  163. from ._task_manager_factory import _TaskManagerFactory
  164. return _TaskManagerFactory._build_manager()._submit(self, callbacks, force, wait, timeout, **properties)
  165. def get_label(self) -> str:
  166. """Returns the task simple label prefixed by its owner label.
  167. Returns:
  168. The label of the task as a string.
  169. """
  170. return self._get_label()
  171. def get_simple_label(self) -> str:
  172. """Returns the task simple label.
  173. Returns:
  174. The simple label of the task as a string.
  175. """
  176. return self._get_simple_label()
  177. @_make_event.register(Task)
  178. def _make_event_for_task(
  179. task: Task,
  180. operation: EventOperation,
  181. /,
  182. attribute_name: Optional[str] = None,
  183. attribute_value: Optional[Any] = None,
  184. **kwargs,
  185. ) -> Event:
  186. metadata = {"version": task.version, "config_id": task.config_id, **kwargs}
  187. return Event(
  188. entity_type=EventEntityType.TASK,
  189. entity_id=task.id,
  190. operation=operation,
  191. attribute_name=attribute_name,
  192. attribute_value=attribute_value,
  193. metadata=metadata,
  194. )