test_task_manager.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import uuid
  12. from datetime import datetime
  13. from unittest import mock
  14. import pytest
  15. from taipy.config.common.scope import Scope
  16. from taipy.config.config import Config
  17. from taipy.core import taipy
  18. from taipy.core._orchestrator._orchestrator import _Orchestrator
  19. from taipy.core._version._version_manager import _VersionManager
  20. from taipy.core.data._data_manager import _DataManager
  21. from taipy.core.data.in_memory import InMemoryDataNode
  22. from taipy.core.exceptions.exceptions import ModelNotFound, NonExistingTask
  23. from taipy.core.notification._submittable_status_cache import SubmittableStatusCache
  24. from taipy.core.scenario._scenario_manager_factory import _ScenarioManagerFactory
  25. from taipy.core.task._task_manager import _TaskManager
  26. from taipy.core.task._task_manager_factory import _TaskManagerFactory
  27. from taipy.core.task.task import Task
  28. from taipy.core.task.task_id import TaskId
  29. def test_create_and_save():
  30. input_configs = [Config.configure_data_node("my_input", "in_memory")]
  31. output_configs = Config.configure_data_node("my_output", "in_memory")
  32. task_config = Config.configure_task("foo", print, input_configs, output_configs)
  33. task = _create_task_from_config(task_config)
  34. assert task.id is not None
  35. assert task.config_id == "foo"
  36. assert len(task.input) == 1
  37. assert len(_DataManager._get_all()) == 2
  38. assert task.my_input.id is not None
  39. assert task.my_input.config_id == "my_input"
  40. assert task.my_output.id is not None
  41. assert task.my_output.config_id == "my_output"
  42. assert task.function == print
  43. assert task.parent_ids == set()
  44. task_retrieved_from_manager = _TaskManager._get(task.id)
  45. assert task_retrieved_from_manager.id == task.id
  46. assert task_retrieved_from_manager.config_id == task.config_id
  47. assert len(task_retrieved_from_manager.input) == len(task.input)
  48. assert task_retrieved_from_manager.my_input.id is not None
  49. assert task_retrieved_from_manager.my_input.config_id == task.my_input.config_id
  50. assert task_retrieved_from_manager.my_output.id is not None
  51. assert task_retrieved_from_manager.my_output.config_id == task.my_output.config_id
  52. assert task_retrieved_from_manager.function == task.function
  53. assert task_retrieved_from_manager.parent_ids == set()
  54. def test_do_not_recreate_existing_data_node():
  55. input_config = Config.configure_data_node("my_input", "in_memory", scope=Scope.SCENARIO)
  56. output_config = Config.configure_data_node("my_output", "in_memory", scope=Scope.SCENARIO)
  57. _DataManager._create_and_set(input_config, "scenario_id", "task_id")
  58. assert len(_DataManager._get_all()) == 1
  59. task_config = Config.configure_task("foo", print, input_config, output_config)
  60. _create_task_from_config(task_config, scenario_id="scenario_id")
  61. assert len(_DataManager._get_all()) == 2
  62. def test_assign_task_as_parent_of_datanode():
  63. dn_config_1 = Config.configure_data_node("dn_1", "in_memory", scope=Scope.SCENARIO)
  64. dn_config_2 = Config.configure_data_node("dn_2", "in_memory", scope=Scope.SCENARIO)
  65. dn_config_3 = Config.configure_data_node("dn_3", "in_memory", scope=Scope.SCENARIO)
  66. task_config_1 = Config.configure_task("task_1", print, dn_config_1, dn_config_2)
  67. task_config_2 = Config.configure_task("task_2", print, dn_config_2, dn_config_3)
  68. tasks = _TaskManager._bulk_get_or_create([task_config_1, task_config_2], "cycle_id", "scenario_id")
  69. assert len(_DataManager._get_all()) == 3
  70. assert len(_TaskManager._get_all()) == 2
  71. assert len(tasks) == 2
  72. dns = {dn.config_id: dn for dn in _DataManager._get_all()}
  73. assert dns["dn_1"].parent_ids == {tasks[0].id}
  74. assert dns["dn_2"].parent_ids == {tasks[0].id, tasks[1].id}
  75. assert dns["dn_3"].parent_ids == {tasks[1].id}
def test_do_not_recreate_existing_task():
    """Check when `_create_task_from_config` reuses a task and when it creates a new one.

    The same task config is submitted repeatedly with different extra
    positional arguments, and after every call the total number of stored
    tasks and the identity of the returned task are asserted.

    NOTE(review): the extra positional arguments are forwarded unchanged to
    `_TaskManager._bulk_get_or_create`; they are presumably
    (cycle_id, scenario_id) -- confirm against the manager's signature.
    """
    input_config_scope_scenario = Config.configure_data_node("my_input_1", "in_memory", Scope.SCENARIO)
    output_config_scope_scenario = Config.configure_data_node("my_output_1", "in_memory", Scope.SCENARIO)
    task_config_1 = Config.configure_task("bar", print, input_config_scope_scenario, output_config_scope_scenario)
    # task_config_1 only uses data nodes configured with Scope.SCENARIO

    task_1 = _create_task_from_config(task_config_1)
    assert len(_TaskManager._get_all()) == 1
    task_2 = _create_task_from_config(task_config_1)  # Do not create. It already exists for None scenario
    assert len(_TaskManager._get_all()) == 1
    assert task_1.id == task_2.id
    task_3 = _create_task_from_config(task_config_1, None, None)  # Do not create. It already exists for None scenario
    assert len(_TaskManager._get_all()) == 1
    assert task_1.id == task_2.id
    assert task_2.id == task_3.id
    task_4 = _create_task_from_config(task_config_1, None, "scenario_1")  # Create. No task exists yet for scenario_1
    assert len(_TaskManager._get_all()) == 2
    assert task_1.id == task_2.id
    assert task_2.id == task_3.id
    assert task_3.id != task_4.id
    task_5 = _create_task_from_config(
        task_config_1, None, "scenario_1"
    )  # Do not create. It already exists for scenario_1
    assert len(_TaskManager._get_all()) == 2
    assert task_1.id == task_2.id
    assert task_2.id == task_3.id
    assert task_3.id != task_4.id
    assert task_4.id == task_5.id
    # Create. "scenario_2" has not been used before, so a third task is expected
    task_6 = _create_task_from_config(task_config_1, None, "scenario_2")
    assert len(_TaskManager._get_all()) == 3
    assert task_1.id == task_2.id
    assert task_2.id == task_3.id
    assert task_3.id != task_4.id
    assert task_4.id == task_5.id
    assert task_5.id != task_6.id
    assert task_3.id != task_6.id

    input_config_scope_cycle = Config.configure_data_node("my_input_2", "in_memory", Scope.CYCLE)
    output_config_scope_cycle = Config.configure_data_node("my_output_2", "in_memory", Scope.CYCLE)
    task_config_2 = Config.configure_task("xyz", print, input_config_scope_cycle, output_config_scope_cycle)
    # task_config_2 only uses data nodes configured with Scope.CYCLE

    task_7 = _create_task_from_config(task_config_2)
    assert len(_TaskManager._get_all()) == 4
    task_8 = _create_task_from_config(task_config_2)  # Do not create. It already exists for None cycle
    assert len(_TaskManager._get_all()) == 4
    assert task_7.id == task_8.id
    task_9 = _create_task_from_config(task_config_2, None, None)  # Do not create. It already exists for None cycle
    assert len(_TaskManager._get_all()) == 4
    assert task_7.id == task_8.id
    assert task_8.id == task_9.id
    task_10 = _create_task_from_config(
        task_config_2, None, "scenario"
    )  # Do not create. It already exists for None cycle
    assert len(_TaskManager._get_all()) == 4
    assert task_7.id == task_8.id
    assert task_8.id == task_9.id
    assert task_9.id == task_10.id
    task_11 = _create_task_from_config(
        task_config_2, None, "scenario"
    )  # Do not create. It already exists for None cycle
    assert len(_TaskManager._get_all()) == 4
    assert task_7.id == task_8.id
    assert task_8.id == task_9.id
    assert task_9.id == task_10.id
    assert task_10.id == task_11.id
    # Create. "cycle" has not been used before, so a fifth task is expected
    task_12 = _create_task_from_config(task_config_2, "cycle", None)
    assert len(_TaskManager._get_all()) == 5
    assert task_7.id == task_8.id
    assert task_8.id == task_9.id
    assert task_9.id == task_10.id
    assert task_10.id == task_11.id
    assert task_11.id != task_12.id
    # Do not create. The same call as for task_12 must return the same task
    task_13 = _create_task_from_config(task_config_2, "cycle", None)
    assert len(_TaskManager._get_all()) == 5
    assert task_7.id == task_8.id
    assert task_8.id == task_9.id
    assert task_9.id == task_10.id
    assert task_10.id == task_11.id
    assert task_11.id != task_12.id
    assert task_12.id == task_13.id
  154. def test_set_and_get_task():
  155. task_id_1 = TaskId("id1")
  156. first_task = Task("name_1", {}, print, [], [], task_id_1)
  157. task_id_2 = TaskId("id2")
  158. second_task = Task("name_2", {}, print, [], [], task_id_2)
  159. third_task_with_same_id_as_first_task = Task("name_is_not_1_anymore", {}, print, [], [], task_id_1)
  160. # No task at initialization
  161. assert len(_TaskManager._get_all()) == 0
  162. assert _TaskManager._get(task_id_1) is None
  163. assert _TaskManager._get(first_task) is None
  164. assert _TaskManager._get(task_id_2) is None
  165. assert _TaskManager._get(second_task) is None
  166. # Save one task. We expect to have only one task stored
  167. _TaskManager._set(first_task)
  168. assert len(_TaskManager._get_all()) == 1
  169. assert _TaskManager._get(task_id_1).id == first_task.id
  170. assert _TaskManager._get(first_task).id == first_task.id
  171. assert _TaskManager._get(task_id_2) is None
  172. assert _TaskManager._get(second_task) is None
  173. # Save a second task. Now, we expect to have a total of two tasks stored
  174. _TaskManager._set(second_task)
  175. assert len(_TaskManager._get_all()) == 2
  176. assert _TaskManager._get(task_id_1).id == first_task.id
  177. assert _TaskManager._get(first_task).id == first_task.id
  178. assert _TaskManager._get(task_id_2).id == second_task.id
  179. assert _TaskManager._get(second_task).id == second_task.id
  180. # We save the first task again. We expect nothing to change
  181. _TaskManager._set(first_task)
  182. assert len(_TaskManager._get_all()) == 2
  183. assert _TaskManager._get(task_id_1).id == first_task.id
  184. assert _TaskManager._get(first_task).id == first_task.id
  185. assert _TaskManager._get(task_id_2).id == second_task.id
  186. assert _TaskManager._get(second_task).id == second_task.id
  187. # We save a third task with same id as the first one.
  188. # We expect the first task to be updated
  189. _TaskManager._set(third_task_with_same_id_as_first_task)
  190. assert len(_TaskManager._get_all()) == 2
  191. assert _TaskManager._get(task_id_1).id == third_task_with_same_id_as_first_task.id
  192. assert _TaskManager._get(task_id_1).config_id == third_task_with_same_id_as_first_task.config_id
  193. assert _TaskManager._get(first_task).id == third_task_with_same_id_as_first_task.id
  194. assert _TaskManager._get(task_id_2).id == second_task.id
  195. assert _TaskManager._get(second_task).id == second_task.id
  196. def test_get_all_on_multiple_versions_environment():
  197. # Create 5 tasks with 2 versions each
  198. # Only version 1.0 has the task with config_id = "config_id_1"
  199. # Only version 2.0 has the task with config_id = "config_id_6"
  200. for version in range(1, 3):
  201. for i in range(5):
  202. _TaskManager._set(
  203. Task(
  204. f"config_id_{i+version}", {}, print, [], [], id=TaskId(f"id{i}_v{version}"), version=f"{version}.0"
  205. )
  206. )
  207. _VersionManager._set_experiment_version("1.0")
  208. assert len(_TaskManager._get_all()) == 5
  209. assert len(_TaskManager._get_all_by(filters=[{"version": "1.0", "config_id": "config_id_1"}])) == 1
  210. assert len(_TaskManager._get_all_by(filters=[{"version": "1.0", "config_id": "config_id_6"}])) == 0
  211. _VersionManager._set_experiment_version("2.0")
  212. assert len(_TaskManager._get_all()) == 5
  213. assert len(_TaskManager._get_all_by(filters=[{"version": "2.0", "config_id": "config_id_1"}])) == 0
  214. assert len(_TaskManager._get_all_by(filters=[{"version": "2.0", "config_id": "config_id_6"}])) == 1
  215. _VersionManager._set_development_version("1.0")
  216. assert len(_TaskManager._get_all()) == 5
  217. assert len(_TaskManager._get_all_by(filters=[{"version": "1.0", "config_id": "config_id_1"}])) == 1
  218. assert len(_TaskManager._get_all_by(filters=[{"version": "1.0", "config_id": "config_id_6"}])) == 0
  219. _VersionManager._set_development_version("2.0")
  220. assert len(_TaskManager._get_all()) == 5
  221. assert len(_TaskManager._get_all_by(filters=[{"version": "2.0", "config_id": "config_id_1"}])) == 0
  222. assert len(_TaskManager._get_all_by(filters=[{"version": "2.0", "config_id": "config_id_6"}])) == 1
  223. def test_ensure_conservation_of_order_of_data_nodes_on_task_creation():
  224. embedded_1 = Config.configure_data_node("dn_1", "in_memory", scope=Scope.SCENARIO)
  225. embedded_2 = Config.configure_data_node("dn_2", "in_memory", scope=Scope.SCENARIO)
  226. embedded_3 = Config.configure_data_node("a_dn_3", "in_memory", scope=Scope.SCENARIO)
  227. embedded_4 = Config.configure_data_node("dn_4", "in_memory", scope=Scope.SCENARIO)
  228. embedded_5 = Config.configure_data_node("dn_5", "in_memory", scope=Scope.SCENARIO)
  229. input = [embedded_1, embedded_2, embedded_3]
  230. output = [embedded_4, embedded_5]
  231. task_config_1 = Config.configure_task("name_1", print, input, output)
  232. task_config_2 = Config.configure_task("name_2", print, input, output)
  233. task_1, task_2 = _TaskManager._bulk_get_or_create([task_config_1, task_config_2])
  234. assert [i.config_id for i in task_1.input.values()] == [embedded_1.id, embedded_2.id, embedded_3.id]
  235. assert [o.config_id for o in task_1.output.values()] == [embedded_4.id, embedded_5.id]
  236. assert [i.config_id for i in task_2.input.values()] == [embedded_1.id, embedded_2.id, embedded_3.id]
  237. assert [o.config_id for o in task_2.output.values()] == [embedded_4.id, embedded_5.id]
  238. def test_delete_raise_exception():
  239. dn_input_config_1 = Config.configure_data_node(
  240. "my_input_1", "in_memory", scope=Scope.SCENARIO, default_data="testing"
  241. )
  242. dn_output_config_1 = Config.configure_data_node("my_output_1", "in_memory")
  243. task_config_1 = Config.configure_task("task_config_1", print, dn_input_config_1, dn_output_config_1)
  244. task_1 = _create_task_from_config(task_config_1)
  245. _TaskManager._delete(task_1.id)
  246. with pytest.raises(ModelNotFound):
  247. _TaskManager._delete(task_1.id)
  248. def test_hard_delete():
  249. dn_input_config_1 = Config.configure_data_node(
  250. "my_input_1", "in_memory", scope=Scope.SCENARIO, default_data="testing"
  251. )
  252. dn_output_config_1 = Config.configure_data_node("my_output_1", "in_memory")
  253. task_config_1 = Config.configure_task("task_config_1", print, dn_input_config_1, dn_output_config_1)
  254. task_1 = _create_task_from_config(task_config_1)
  255. assert len(_TaskManager._get_all()) == 1
  256. assert len(_DataManager._get_all()) == 2
  257. _TaskManager._hard_delete(task_1.id)
  258. assert len(_TaskManager._get_all()) == 0
  259. assert len(_DataManager._get_all()) == 2
def test_is_submittable():
    """Check `_is_submittable` and the content of `SubmittableStatusCache` while data nodes change state.

    A scenario with a single task reading two pickle data nodes is created.
    The test then toggles `edit_in_progress` and `last_edit_date` on the
    data nodes and, after each change, asserts both the cache dictionaries
    (`_submittable_id_datanodes`, `_datanode_id_submittables`) and the result
    of `_is_submittable` for the task and the scenario.
    """
    task_manager = _TaskManagerFactory._build_manager()
    scenario_manager = _ScenarioManagerFactory._build_manager()

    assert len(_TaskManager._get_all()) == 0

    dn_config_1 = Config.configure_pickle_data_node("dn_1", default_data=10)
    dn_config_2 = Config.configure_pickle_data_node("dn_2", default_data=15)
    task_config = Config.configure_task("task", print, [dn_config_1, dn_config_2])
    scenario_config = Config.configure_scenario("scenario", [task_config])
    scenario = scenario_manager._create(scenario_config)
    task = scenario.tasks["task"]
    dn_1 = scenario.dn_1
    dn_2 = scenario.dn_2

    # Initial state: nothing recorded in the cache, task and scenario are submittable.
    assert len(task_manager._get_all()) == 1
    assert len(scenario_manager._get_all()) == 1
    assert scenario.id not in SubmittableStatusCache._submittable_id_datanodes
    assert task.id not in SubmittableStatusCache._submittable_id_datanodes
    assert task_manager._is_submittable(task)
    assert task_manager._is_submittable(task.id)
    assert scenario_manager._is_submittable(scenario)
    assert scenario_manager._is_submittable(scenario.id)
    # An id that does not match any created task is not submittable.
    assert not task_manager._is_submittable("Task_temp")

    # Phase 1: dn_1 is being edited -> recorded for both the task and the scenario.
    dn_1.edit_in_progress = True
    assert scenario.id in SubmittableStatusCache._submittable_id_datanodes
    assert task.id in SubmittableStatusCache._submittable_id_datanodes
    assert dn_1.id in SubmittableStatusCache._submittable_id_datanodes[scenario.id]
    assert dn_1.id in SubmittableStatusCache._submittable_id_datanodes[task.id]
    assert dn_1.id in SubmittableStatusCache._datanode_id_submittables
    assert scenario.id in SubmittableStatusCache._datanode_id_submittables[dn_1.id]
    assert task.id in SubmittableStatusCache._datanode_id_submittables[dn_1.id]
    assert (
        SubmittableStatusCache._submittable_id_datanodes[scenario.id][dn_1.id] == f"DataNode {dn_1.id} is being edited"
    )
    assert SubmittableStatusCache._submittable_id_datanodes[task.id][dn_1.id] == f"DataNode {dn_1.id} is being edited"
    assert not scenario_manager._is_submittable(scenario)
    assert not task_manager._is_submittable(task)
    assert not task_manager._is_submittable(task.id)

    # Phase 2: the edit on dn_1 ends -> the cache entries are gone again.
    dn_1.edit_in_progress = False
    assert scenario.id not in SubmittableStatusCache._submittable_id_datanodes
    assert task.id not in SubmittableStatusCache._submittable_id_datanodes
    assert dn_1.id not in SubmittableStatusCache._datanode_id_submittables
    assert scenario_manager._is_submittable(scenario)
    assert task_manager._is_submittable(task)
    assert task_manager._is_submittable(task.id)

    # Phase 3: dn_1 has no last edit date and dn_2 is being edited -> two reasons are recorded.
    dn_1.last_edit_date = None
    dn_2.edit_in_progress = True
    assert scenario.id in SubmittableStatusCache._submittable_id_datanodes
    assert task.id in SubmittableStatusCache._submittable_id_datanodes
    assert dn_1.id in SubmittableStatusCache._submittable_id_datanodes[scenario.id]
    assert dn_1.id in SubmittableStatusCache._submittable_id_datanodes[task.id]
    assert dn_2.id in SubmittableStatusCache._submittable_id_datanodes[scenario.id]
    assert dn_2.id in SubmittableStatusCache._submittable_id_datanodes[task.id]
    assert dn_1.id in SubmittableStatusCache._datanode_id_submittables
    assert scenario.id in SubmittableStatusCache._datanode_id_submittables[dn_1.id]
    assert task.id in SubmittableStatusCache._datanode_id_submittables[dn_1.id]
    assert dn_2.id in SubmittableStatusCache._datanode_id_submittables
    assert scenario.id in SubmittableStatusCache._datanode_id_submittables[dn_2.id]
    assert task.id in SubmittableStatusCache._datanode_id_submittables[dn_2.id]
    assert (
        SubmittableStatusCache._submittable_id_datanodes[scenario.id][dn_1.id] == f"DataNode {dn_1.id} is not written"
    )
    assert SubmittableStatusCache._submittable_id_datanodes[task.id][dn_1.id] == f"DataNode {dn_1.id} is not written"
    assert (
        SubmittableStatusCache._submittable_id_datanodes[scenario.id][dn_2.id] == f"DataNode {dn_2.id} is being edited"
    )
    assert SubmittableStatusCache._submittable_id_datanodes[task.id][dn_2.id] == f"DataNode {dn_2.id} is being edited"
    assert not scenario_manager._is_submittable(scenario)
    assert not task_manager._is_submittable(task)
    assert not task_manager._is_submittable(task.id)

    # Phase 4: dn_1 gets a last edit date -> only the dn_2 reason remains.
    dn_1.last_edit_date = datetime.now()
    assert scenario.id in SubmittableStatusCache._submittable_id_datanodes
    assert task.id in SubmittableStatusCache._submittable_id_datanodes
    assert dn_1.id not in SubmittableStatusCache._submittable_id_datanodes[scenario.id]
    assert dn_1.id not in SubmittableStatusCache._submittable_id_datanodes[task.id]
    assert dn_2.id in SubmittableStatusCache._submittable_id_datanodes[scenario.id]
    assert dn_2.id in SubmittableStatusCache._submittable_id_datanodes[task.id]
    assert dn_1.id not in SubmittableStatusCache._datanode_id_submittables
    assert dn_2.id in SubmittableStatusCache._datanode_id_submittables
    assert scenario.id in SubmittableStatusCache._datanode_id_submittables[dn_2.id]
    assert task.id in SubmittableStatusCache._datanode_id_submittables[dn_2.id]
    assert (
        SubmittableStatusCache._submittable_id_datanodes[scenario.id][dn_2.id] == f"DataNode {dn_2.id} is being edited"
    )
    assert SubmittableStatusCache._submittable_id_datanodes[task.id][dn_2.id] == f"DataNode {dn_2.id} is being edited"
    assert not scenario_manager._is_submittable(scenario)
    assert not task_manager._is_submittable(task)
    assert not task_manager._is_submittable(task.id)

    # Phase 5: the edit on dn_2 ends -> everything is submittable again.
    dn_2.edit_in_progress = False
    assert scenario.id not in SubmittableStatusCache._submittable_id_datanodes
    assert task.id not in SubmittableStatusCache._submittable_id_datanodes
    # NOTE(review): the next two lines index the cache with keys that were just asserted
    # to be absent. Presumably `_submittable_id_datanodes` is a defaultdict, in which case
    # these lookups re-insert empty entries -- confirm, and consider dropping them.
    assert dn_2.id not in SubmittableStatusCache._submittable_id_datanodes[scenario.id]
    assert dn_2.id not in SubmittableStatusCache._submittable_id_datanodes[task.id]
    assert dn_2.id not in SubmittableStatusCache._datanode_id_submittables
    assert scenario_manager._is_submittable(scenario)
    assert task_manager._is_submittable(task)
    assert task_manager._is_submittable(task.id)
  355. def test_submit_task():
  356. data_node_1 = InMemoryDataNode("foo", Scope.SCENARIO, "s1")
  357. data_node_2 = InMemoryDataNode("bar", Scope.SCENARIO, "s2")
  358. task_1 = Task(
  359. "grault",
  360. {},
  361. print,
  362. [data_node_1],
  363. [data_node_2],
  364. TaskId("t1"),
  365. )
  366. class MockOrchestrator(_Orchestrator):
  367. submit_calls = []
  368. submit_ids = []
  369. def submit_task(self, task, callbacks=None, force=False, wait=False, timeout=None):
  370. submit_id = f"SUBMISSION_{str(uuid.uuid4())}"
  371. self.submit_calls.append(task)
  372. self.submit_ids.append(submit_id)
  373. return None
  374. with mock.patch("taipy.core.task._task_manager._TaskManager._orchestrator", new=MockOrchestrator):
  375. # Task does not exist, we expect an exception
  376. with pytest.raises(NonExistingTask):
  377. _TaskManager._submit(task_1)
  378. with pytest.raises(NonExistingTask):
  379. _TaskManager._submit(task_1.id)
  380. _TaskManager._set(task_1)
  381. _TaskManager._submit(task_1)
  382. call_ids = [call.id for call in MockOrchestrator.submit_calls]
  383. assert call_ids == [task_1.id]
  384. assert len(MockOrchestrator.submit_ids) == 1
  385. _TaskManager._submit(task_1)
  386. assert len(MockOrchestrator.submit_ids) == 2
  387. assert len(MockOrchestrator.submit_ids) == len(set(MockOrchestrator.submit_ids))
  388. _TaskManager._submit(task_1)
  389. assert len(MockOrchestrator.submit_ids) == 3
  390. assert len(MockOrchestrator.submit_ids) == len(set(MockOrchestrator.submit_ids))
  391. def my_print(a, b):
  392. print(a + b) # noqa: T201
  393. def test_submit_task_with_input_dn_wrong_file_path(caplog):
  394. csv_dn_cfg = Config.configure_csv_data_node("wrong_csv_file_path", default_path="wrong_path.csv")
  395. pickle_dn_cfg = Config.configure_pickle_data_node("wrong_pickle_file_path", default_path="wrong_path.pickle")
  396. parquet_dn_cfg = Config.configure_parquet_data_node("wrong_parquet_file_path", default_path="wrong_path.parquet")
  397. task_cfg = Config.configure_task("task", my_print, [csv_dn_cfg, pickle_dn_cfg], parquet_dn_cfg)
  398. task_manager = _TaskManagerFactory._build_manager()
  399. tasks = task_manager._bulk_get_or_create([task_cfg])
  400. task = tasks[0]
  401. taipy.submit(task)
  402. stdout = caplog.text
  403. expected_outputs = [
  404. f"{input_dn.id} cannot be read because it has never been written. Hint: The data node may refer to a wrong "
  405. f"path : {input_dn.path} "
  406. for input_dn in task.input.values()
  407. ]
  408. not_expected_outputs = [
  409. f"{input_dn.id} cannot be read because it has never been written. Hint: The data node may refer to a wrong "
  410. f"path : {input_dn.path} "
  411. for input_dn in task.output.values()
  412. ]
  413. assert all(expected_output in stdout for expected_output in expected_outputs)
  414. assert all(expected_output not in stdout for expected_output in not_expected_outputs)
  415. def test_submit_task_with_one_input_dn_wrong_file_path(caplog):
  416. csv_dn_cfg = Config.configure_csv_data_node("wrong_csv_file_path", default_path="wrong_path.csv")
  417. pickle_dn_cfg = Config.configure_pickle_data_node("pickle_file_path", default_data="value")
  418. parquet_dn_cfg = Config.configure_parquet_data_node("wrong_parquet_file_path", default_path="wrong_path.parquet")
  419. task_cfg = Config.configure_task("task", my_print, [csv_dn_cfg, pickle_dn_cfg], parquet_dn_cfg)
  420. task_manager = _TaskManagerFactory._build_manager()
  421. tasks = task_manager._bulk_get_or_create([task_cfg])
  422. task = tasks[0]
  423. taipy.submit(task)
  424. stdout = caplog.text
  425. expected_outputs = [
  426. f"{input_dn.id} cannot be read because it has never been written. Hint: The data node may refer to a wrong "
  427. f"path : {input_dn.path} "
  428. for input_dn in [task.input["wrong_csv_file_path"]]
  429. ]
  430. not_expected_outputs = [
  431. f"{input_dn.id} cannot be read because it has never been written. Hint: The data node may refer to a wrong "
  432. f"path : {input_dn.path} "
  433. for input_dn in [task.input["pickle_file_path"], task.output["wrong_parquet_file_path"]]
  434. ]
  435. assert all(expected_output in stdout for expected_output in expected_outputs)
  436. assert all(expected_output not in stdout for expected_output in not_expected_outputs)
  437. def test_get_tasks_by_config_id():
  438. dn_config = Config.configure_data_node("dn", scope=Scope.SCENARIO)
  439. task_config_1 = Config.configure_task("t1", print, dn_config)
  440. task_config_2 = Config.configure_task("t2", print, dn_config)
  441. task_config_3 = Config.configure_task("t3", print, dn_config)
  442. t_1_1 = _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_1")[0]
  443. t_1_2 = _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_2")[0]
  444. t_1_3 = _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_3")[0]
  445. assert len(_TaskManager._get_all()) == 3
  446. t_2_1 = _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_4")[0]
  447. t_2_2 = _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_5")[0]
  448. assert len(_TaskManager._get_all()) == 5
  449. t_3_1 = _TaskManager._bulk_get_or_create([task_config_3], scenario_id="scenario_6")[0]
  450. assert len(_TaskManager._get_all()) == 6
  451. t1_tasks = _TaskManager._get_by_config_id(task_config_1.id)
  452. assert len(t1_tasks) == 3
  453. assert sorted([t_1_1.id, t_1_2.id, t_1_3.id]) == sorted([task.id for task in t1_tasks])
  454. t2_tasks = _TaskManager._get_by_config_id(task_config_2.id)
  455. assert len(t2_tasks) == 2
  456. assert sorted([t_2_1.id, t_2_2.id]) == sorted([task.id for task in t2_tasks])
  457. t3_tasks = _TaskManager._get_by_config_id(task_config_3.id)
  458. assert len(t3_tasks) == 1
  459. assert sorted([t_3_1.id]) == sorted([task.id for task in t3_tasks])
  460. def test_get_scenarios_by_config_id_in_multiple_versions_environment():
  461. dn_config = Config.configure_data_node("dn", scope=Scope.SCENARIO)
  462. task_config_1 = Config.configure_task("t1", print, dn_config)
  463. task_config_2 = Config.configure_task("t2", print, dn_config)
  464. _VersionManager._set_experiment_version("1.0")
  465. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_1")[0]
  466. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_2")[0]
  467. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_3")[0]
  468. _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_4")[0]
  469. _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_5")[0]
  470. assert len(_TaskManager._get_by_config_id(task_config_1.id)) == 3
  471. assert len(_TaskManager._get_by_config_id(task_config_2.id)) == 2
  472. _VersionManager._set_experiment_version("2.0")
  473. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_1")[0]
  474. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_2")[0]
  475. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_3")[0]
  476. _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_4")[0]
  477. _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_5")[0]
  478. assert len(_TaskManager._get_by_config_id(task_config_1.id)) == 3
  479. assert len(_TaskManager._get_by_config_id(task_config_2.id)) == 2
  480. def _create_task_from_config(task_config, *args, **kwargs):
  481. return _TaskManager._bulk_get_or_create([task_config], *args, **kwargs)[0]