test_task_manager.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. # Copyright 2021-2025 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import uuid
  12. from unittest import mock
  13. import pytest
  14. from taipy import Scope
  15. from taipy.common.config import Config
  16. from taipy.core import taipy
  17. from taipy.core._orchestrator._orchestrator import _Orchestrator
  18. from taipy.core._version._version_manager import _VersionManager
  19. from taipy.core.data._data_manager import _DataManager
  20. from taipy.core.data.in_memory import InMemoryDataNode
  21. from taipy.core.exceptions.exceptions import ModelNotFound, NonExistingTask
  22. from taipy.core.reason import EntityDoesNotExist
  23. from taipy.core.task._task_manager import _TaskManager
  24. from taipy.core.task._task_manager_factory import _TaskManagerFactory
  25. from taipy.core.task.task import Task
  26. from taipy.core.task.task_id import TaskId
  27. def test_create_and_save():
  28. input_configs = [Config.configure_data_node("my_input", "in_memory")]
  29. output_configs = Config.configure_data_node("my_output", "in_memory")
  30. task_config = Config.configure_task("foo", print, input_configs, output_configs)
  31. task = _create_task_from_config(task_config)
  32. assert task.id is not None
  33. assert task.config_id == "foo"
  34. assert len(task.input) == 1
  35. assert len(_DataManager._get_all()) == 2
  36. assert task.my_input.id is not None
  37. assert task.my_input.config_id == "my_input"
  38. assert task.my_output.id is not None
  39. assert task.my_output.config_id == "my_output"
  40. assert task.function == print
  41. assert task.parent_ids == set()
  42. task_retrieved_from_manager = _TaskManager._get(task.id)
  43. assert task_retrieved_from_manager.id == task.id
  44. assert task_retrieved_from_manager.config_id == task.config_id
  45. assert len(task_retrieved_from_manager.input) == len(task.input)
  46. assert task_retrieved_from_manager.my_input.id is not None
  47. assert task_retrieved_from_manager.my_input.config_id == task.my_input.config_id
  48. assert task_retrieved_from_manager.my_output.id is not None
  49. assert task_retrieved_from_manager.my_output.config_id == task.my_output.config_id
  50. assert task_retrieved_from_manager.function == task.function
  51. assert task_retrieved_from_manager.parent_ids == set()
  52. def test_do_not_recreate_existing_data_node():
  53. input_config = Config.configure_data_node("my_input", "in_memory", scope=Scope.SCENARIO)
  54. output_config = Config.configure_data_node("my_output", "in_memory", scope=Scope.SCENARIO)
  55. _DataManager._create(input_config, "scenario_id", "task_id")
  56. assert len(_DataManager._get_all()) == 1
  57. task_config = Config.configure_task("foo", print, input_config, output_config)
  58. _create_task_from_config(task_config, scenario_id="scenario_id")
  59. assert len(_DataManager._get_all()) == 2
  60. def test_assign_task_as_parent_of_datanode():
  61. dn_config_1 = Config.configure_data_node("dn_1", "in_memory", scope=Scope.SCENARIO)
  62. dn_config_2 = Config.configure_data_node("dn_2", "in_memory", scope=Scope.SCENARIO)
  63. dn_config_3 = Config.configure_data_node("dn_3", "in_memory", scope=Scope.SCENARIO)
  64. task_config_1 = Config.configure_task("task_1", print, dn_config_1, dn_config_2)
  65. task_config_2 = Config.configure_task("task_2", print, dn_config_2, dn_config_3)
  66. tasks = _TaskManager._bulk_get_or_create([task_config_1, task_config_2], "cycle_id", "scenario_id")
  67. assert len(_DataManager._get_all()) == 3
  68. assert len(_TaskManager._get_all()) == 2
  69. assert len(tasks) == 2
  70. dns = {dn.config_id: dn for dn in _DataManager._get_all()}
  71. assert dns["dn_1"].parent_ids == {tasks[0].id}
  72. assert dns["dn_2"].parent_ids == {tasks[0].id, tasks[1].id}
  73. assert dns["dn_3"].parent_ids == {tasks[1].id}
  74. def test_do_not_recreate_existing_task():
  75. input_config_scope_scenario = Config.configure_data_node("my_input_1", "in_memory", Scope.SCENARIO)
  76. output_config_scope_scenario = Config.configure_data_node("my_output_1", "in_memory", Scope.SCENARIO)
  77. task_config_1 = Config.configure_task("bar", print, input_config_scope_scenario, output_config_scope_scenario)
  78. # task_config_2 scope is Scenario
  79. task_1 = _create_task_from_config(task_config_1)
  80. assert len(_TaskManager._get_all()) == 1
  81. task_2 = _create_task_from_config(task_config_1) # Do not create. It already exists for None scenario
  82. assert len(_TaskManager._get_all()) == 1
  83. assert task_1.id == task_2.id
  84. task_3 = _create_task_from_config(task_config_1, None, None) # Do not create. It already exists for None scenario
  85. assert len(_TaskManager._get_all()) == 1
  86. assert task_1.id == task_2.id
  87. assert task_2.id == task_3.id
  88. task_4 = _create_task_from_config(task_config_1, None, "scenario_1") # Create even if sequence is the same.
  89. assert len(_TaskManager._get_all()) == 2
  90. assert task_1.id == task_2.id
  91. assert task_2.id == task_3.id
  92. assert task_3.id != task_4.id
  93. task_5 = _create_task_from_config(
  94. task_config_1, None, "scenario_1"
  95. ) # Do not create. It already exists for scenario_1
  96. assert len(_TaskManager._get_all()) == 2
  97. assert task_1.id == task_2.id
  98. assert task_2.id == task_3.id
  99. assert task_3.id != task_4.id
  100. assert task_4.id == task_5.id
  101. task_6 = _create_task_from_config(task_config_1, None, "scenario_2")
  102. assert len(_TaskManager._get_all()) == 3
  103. assert task_1.id == task_2.id
  104. assert task_2.id == task_3.id
  105. assert task_3.id != task_4.id
  106. assert task_4.id == task_5.id
  107. assert task_5.id != task_6.id
  108. assert task_3.id != task_6.id
  109. input_config_scope_cycle = Config.configure_data_node("my_input_2", "in_memory", Scope.CYCLE)
  110. output_config_scope_cycle = Config.configure_data_node("my_output_2", "in_memory", Scope.CYCLE)
  111. task_config_2 = Config.configure_task("xyz", print, input_config_scope_cycle, output_config_scope_cycle)
  112. # task_config_3 scope is Cycle
  113. task_7 = _create_task_from_config(task_config_2)
  114. assert len(_TaskManager._get_all()) == 4
  115. task_8 = _create_task_from_config(task_config_2) # Do not create. It already exists for None cycle
  116. assert len(_TaskManager._get_all()) == 4
  117. assert task_7.id == task_8.id
  118. task_9 = _create_task_from_config(task_config_2, None, None) # Do not create. It already exists for None cycle
  119. assert len(_TaskManager._get_all()) == 4
  120. assert task_7.id == task_8.id
  121. assert task_8.id == task_9.id
  122. task_10 = _create_task_from_config(
  123. task_config_2, None, "scenario"
  124. ) # Do not create. It already exists for None cycle
  125. assert len(_TaskManager._get_all()) == 4
  126. assert task_7.id == task_8.id
  127. assert task_8.id == task_9.id
  128. assert task_9.id == task_10.id
  129. task_11 = _create_task_from_config(
  130. task_config_2, None, "scenario"
  131. ) # Do not create. It already exists for None cycle
  132. assert len(_TaskManager._get_all()) == 4
  133. assert task_7.id == task_8.id
  134. assert task_8.id == task_9.id
  135. assert task_9.id == task_10.id
  136. assert task_10.id == task_11.id
  137. task_12 = _create_task_from_config(task_config_2, "cycle", None)
  138. assert len(_TaskManager._get_all()) == 5
  139. assert task_7.id == task_8.id
  140. assert task_8.id == task_9.id
  141. assert task_9.id == task_10.id
  142. assert task_10.id == task_11.id
  143. assert task_11.id != task_12.id
  144. task_13 = _create_task_from_config(task_config_2, "cycle", None)
  145. assert len(_TaskManager._get_all()) == 5
  146. assert task_7.id == task_8.id
  147. assert task_8.id == task_9.id
  148. assert task_9.id == task_10.id
  149. assert task_10.id == task_11.id
  150. assert task_11.id != task_12.id
  151. assert task_12.id == task_13.id
  152. def test_set_and_get_task():
  153. task_id_1 = TaskId("id1")
  154. first_task = Task("name_1", {}, print, [], [], task_id_1)
  155. task_id_2 = TaskId("id2")
  156. second_task = Task("name_2", {}, print, [], [], task_id_2)
  157. third_task_with_same_id_as_first_task = Task("name_is_not_1_anymore", {}, print, [], [], task_id_1)
  158. # No task at initialization
  159. assert len(_TaskManager._get_all()) == 0
  160. assert _TaskManager._get(task_id_1) is None
  161. assert _TaskManager._get(first_task) is None
  162. assert _TaskManager._get(task_id_2) is None
  163. assert _TaskManager._get(second_task) is None
  164. # Save one task. We expect to have only one task stored
  165. _TaskManager._repository._save(first_task)
  166. assert len(_TaskManager._get_all()) == 1
  167. assert _TaskManager._get(task_id_1).id == first_task.id
  168. assert _TaskManager._get(first_task).id == first_task.id
  169. assert _TaskManager._get(task_id_2) is None
  170. assert _TaskManager._get(second_task) is None
  171. # Save a second task. Now, we expect to have a total of two tasks stored
  172. _TaskManager._repository._save(second_task)
  173. assert len(_TaskManager._get_all()) == 2
  174. assert _TaskManager._get(task_id_1).id == first_task.id
  175. assert _TaskManager._get(first_task).id == first_task.id
  176. assert _TaskManager._get(task_id_2).id == second_task.id
  177. assert _TaskManager._get(second_task).id == second_task.id
  178. # We save the first task again. We expect nothing to change
  179. _TaskManager._update(first_task)
  180. assert len(_TaskManager._get_all()) == 2
  181. assert _TaskManager._get(task_id_1).id == first_task.id
  182. assert _TaskManager._get(first_task).id == first_task.id
  183. assert _TaskManager._get(task_id_2).id == second_task.id
  184. assert _TaskManager._get(second_task).id == second_task.id
  185. # We save a third task with same id as the first one.
  186. # We expect the first task to be updated
  187. _TaskManager._repository._save(third_task_with_same_id_as_first_task)
  188. assert len(_TaskManager._get_all()) == 2
  189. assert _TaskManager._get(task_id_1).id == third_task_with_same_id_as_first_task.id
  190. assert _TaskManager._get(task_id_1).config_id == third_task_with_same_id_as_first_task.config_id
  191. assert _TaskManager._get(first_task).id == third_task_with_same_id_as_first_task.id
  192. assert _TaskManager._get(task_id_2).id == second_task.id
  193. assert _TaskManager._get(second_task).id == second_task.id
  194. def test_get_all_on_multiple_versions_environment():
  195. # Create 5 tasks with 2 versions each
  196. # Only version 1.0 has the task with config_id = "config_id_1"
  197. # Only version 2.0 has the task with config_id = "config_id_6"
  198. for version in range(1, 3):
  199. for i in range(5):
  200. _TaskManager._repository._save(
  201. Task(
  202. f"config_id_{i+version}", {}, print, [], [], id=TaskId(f"id{i}_v{version}"), version=f"{version}.0"
  203. )
  204. )
  205. _VersionManager._set_experiment_version("1.0")
  206. assert len(_TaskManager._get_all()) == 5
  207. assert len(_TaskManager._get_all_by(filters=[{"version": "1.0", "config_id": "config_id_1"}])) == 1
  208. assert len(_TaskManager._get_all_by(filters=[{"version": "1.0", "config_id": "config_id_6"}])) == 0
  209. _VersionManager._set_experiment_version("2.0")
  210. assert len(_TaskManager._get_all()) == 5
  211. assert len(_TaskManager._get_all_by(filters=[{"version": "2.0", "config_id": "config_id_1"}])) == 0
  212. assert len(_TaskManager._get_all_by(filters=[{"version": "2.0", "config_id": "config_id_6"}])) == 1
  213. _VersionManager._set_development_version("1.0")
  214. assert len(_TaskManager._get_all()) == 5
  215. assert len(_TaskManager._get_all_by(filters=[{"version": "1.0", "config_id": "config_id_1"}])) == 1
  216. assert len(_TaskManager._get_all_by(filters=[{"version": "1.0", "config_id": "config_id_6"}])) == 0
  217. _VersionManager._set_development_version("2.0")
  218. assert len(_TaskManager._get_all()) == 5
  219. assert len(_TaskManager._get_all_by(filters=[{"version": "2.0", "config_id": "config_id_1"}])) == 0
  220. assert len(_TaskManager._get_all_by(filters=[{"version": "2.0", "config_id": "config_id_6"}])) == 1
  221. def test_ensure_conservation_of_order_of_data_nodes_on_task_creation():
  222. embedded_1 = Config.configure_data_node("dn_1", "in_memory", scope=Scope.SCENARIO)
  223. embedded_2 = Config.configure_data_node("dn_2", "in_memory", scope=Scope.SCENARIO)
  224. embedded_3 = Config.configure_data_node("a_dn_3", "in_memory", scope=Scope.SCENARIO)
  225. embedded_4 = Config.configure_data_node("dn_4", "in_memory", scope=Scope.SCENARIO)
  226. embedded_5 = Config.configure_data_node("dn_5", "in_memory", scope=Scope.SCENARIO)
  227. input = [embedded_1, embedded_2, embedded_3]
  228. output = [embedded_4, embedded_5]
  229. task_config_1 = Config.configure_task("name_1", print, input, output)
  230. task_config_2 = Config.configure_task("name_2", print, input, output)
  231. task_1, task_2 = _TaskManager._bulk_get_or_create([task_config_1, task_config_2])
  232. assert [i.config_id for i in task_1.input.values()] == [embedded_1.id, embedded_2.id, embedded_3.id]
  233. assert [o.config_id for o in task_1.output.values()] == [embedded_4.id, embedded_5.id]
  234. assert [i.config_id for i in task_2.input.values()] == [embedded_1.id, embedded_2.id, embedded_3.id]
  235. assert [o.config_id for o in task_2.output.values()] == [embedded_4.id, embedded_5.id]
  236. def test_delete_raise_exception():
  237. dn_input_config_1 = Config.configure_data_node(
  238. "my_input_1", "in_memory", scope=Scope.SCENARIO, default_data="testing"
  239. )
  240. dn_output_config_1 = Config.configure_data_node("my_output_1", "in_memory")
  241. task_config_1 = Config.configure_task("task_config_1", print, dn_input_config_1, dn_output_config_1)
  242. task_1 = _create_task_from_config(task_config_1)
  243. _TaskManager._delete(task_1.id)
  244. with pytest.raises(ModelNotFound):
  245. _TaskManager._delete(task_1.id)
  246. def test_hard_delete():
  247. dn_input_config_1 = Config.configure_data_node(
  248. "my_input_1", "in_memory", scope=Scope.SCENARIO, default_data="testing"
  249. )
  250. dn_output_config_1 = Config.configure_data_node("my_output_1", "in_memory")
  251. task_config_1 = Config.configure_task("task_config_1", print, dn_input_config_1, dn_output_config_1)
  252. task_1 = _create_task_from_config(task_config_1)
  253. assert len(_TaskManager._get_all()) == 1
  254. assert len(_DataManager._get_all()) == 2
  255. _TaskManager._hard_delete(task_1.id)
  256. assert len(_TaskManager._get_all()) == 0
  257. assert len(_DataManager._get_all()) == 2
  258. def test_is_submittable():
  259. assert len(_TaskManager._get_all()) == 0
  260. dn_config = Config.configure_in_memory_data_node("dn", 10)
  261. task_config = Config.configure_task("task", print, [dn_config])
  262. task = _TaskManager._bulk_get_or_create([task_config])[0]
  263. rc = _TaskManager._is_submittable("some_task")
  264. assert not rc
  265. assert "Entity 'some_task' does not exist in the repository" in rc.reasons
  266. assert len(_TaskManager._get_all()) == 1
  267. assert _TaskManager._is_submittable(task)
  268. assert _TaskManager._is_submittable(task.id)
  269. assert not _TaskManager._is_submittable("Task_temp")
  270. task.input["dn"].edit_in_progress = True
  271. assert not _TaskManager._is_submittable(task)
  272. assert not _TaskManager._is_submittable(task.id)
  273. task.input["dn"].edit_in_progress = False
  274. assert _TaskManager._is_submittable(task)
  275. assert _TaskManager._is_submittable(task.id)
  276. def test_submit_task():
  277. data_node_1 = InMemoryDataNode("foo", Scope.SCENARIO, "s1")
  278. _DataManager._repository._save(data_node_1)
  279. data_node_2 = InMemoryDataNode("bar", Scope.SCENARIO, "s2")
  280. _DataManager._repository._save(data_node_2)
  281. task_1 = Task(
  282. "grault",
  283. {},
  284. print,
  285. [data_node_1],
  286. [data_node_2],
  287. TaskId("t1"),
  288. )
  289. class MockOrchestrator(_Orchestrator):
  290. submit_calls = []
  291. submit_ids = []
  292. def submit_task(self, task, callbacks=None, force=False, wait=False, timeout=None):
  293. submit_id = f"SUBMISSION_{str(uuid.uuid4())}"
  294. self.submit_calls.append(task)
  295. self.submit_ids.append(submit_id)
  296. return None
  297. with mock.patch("taipy.core.task._task_manager._TaskManager._orchestrator", new=MockOrchestrator):
  298. # Task does not exist, we expect an exception
  299. with pytest.raises(NonExistingTask):
  300. _TaskManager._submit(task_1)
  301. with pytest.raises(NonExistingTask):
  302. _TaskManager._submit(task_1.id)
  303. _TaskManager._repository._save(task_1)
  304. _TaskManager._submit(task_1)
  305. call_ids = [call.id for call in MockOrchestrator.submit_calls]
  306. assert call_ids == [task_1.id]
  307. assert len(MockOrchestrator.submit_ids) == 1
  308. _TaskManager._submit(task_1)
  309. assert len(MockOrchestrator.submit_ids) == 2
  310. assert len(MockOrchestrator.submit_ids) == len(set(MockOrchestrator.submit_ids))
  311. _TaskManager._submit(task_1)
  312. assert len(MockOrchestrator.submit_ids) == 3
  313. assert len(MockOrchestrator.submit_ids) == len(set(MockOrchestrator.submit_ids))
  314. def my_print(a, b):
  315. print(a + b) # noqa: T201
  316. def test_submit_task_with_input_dn_wrong_file_path(caplog):
  317. csv_dn_cfg = Config.configure_csv_data_node("wrong_csv_file_path", default_path="wrong_path.csv")
  318. pickle_dn_cfg = Config.configure_pickle_data_node("wrong_pickle_file_path", default_path="wrong_path.pickle")
  319. parquet_dn_cfg = Config.configure_parquet_data_node("wrong_parquet_file_path", default_path="wrong_path.parquet")
  320. task_cfg = Config.configure_task("task", my_print, [csv_dn_cfg, pickle_dn_cfg], parquet_dn_cfg)
  321. task_manager = _TaskManagerFactory._build_manager()
  322. tasks = task_manager._bulk_get_or_create([task_cfg])
  323. task = tasks[0]
  324. taipy.submit(task)
  325. stdout = caplog.text
  326. expected_outputs = [
  327. f"{input_dn.id} cannot be read because it has never been written. Hint: The data node may refer to a wrong "
  328. f"path : {input_dn.path} "
  329. for input_dn in task.input.values()
  330. ]
  331. not_expected_outputs = [
  332. f"{input_dn.id} cannot be read because it has never been written. Hint: The data node may refer to a wrong "
  333. f"path : {input_dn.path} "
  334. for input_dn in task.output.values()
  335. ]
  336. assert all(expected_output in stdout for expected_output in expected_outputs)
  337. assert all(expected_output not in stdout for expected_output in not_expected_outputs)
  338. def test_submit_task_with_one_input_dn_wrong_file_path(caplog):
  339. csv_dn_cfg = Config.configure_csv_data_node("wrong_csv_file_path", default_path="wrong_path.csv")
  340. pickle_dn_cfg = Config.configure_pickle_data_node("pickle_file_path", default_data="value")
  341. parquet_dn_cfg = Config.configure_parquet_data_node("wrong_parquet_file_path", default_path="wrong_path.parquet")
  342. task_cfg = Config.configure_task("task", my_print, [csv_dn_cfg, pickle_dn_cfg], parquet_dn_cfg)
  343. task_manager = _TaskManagerFactory._build_manager()
  344. tasks = task_manager._bulk_get_or_create([task_cfg])
  345. task = tasks[0]
  346. taipy.submit(task)
  347. stdout = caplog.text
  348. expected_outputs = [
  349. f"{input_dn.id} cannot be read because it has never been written. Hint: The data node may refer to a wrong "
  350. f"path : {input_dn.path} "
  351. for input_dn in [task.input["wrong_csv_file_path"]]
  352. ]
  353. not_expected_outputs = [
  354. f"{input_dn.id} cannot be read because it has never been written. Hint: The data node may refer to a wrong "
  355. f"path : {input_dn.path} "
  356. for input_dn in [task.input["pickle_file_path"], task.output["wrong_parquet_file_path"]]
  357. ]
  358. assert all(expected_output in stdout for expected_output in expected_outputs)
  359. assert all(expected_output not in stdout for expected_output in not_expected_outputs)
  360. def test_get_tasks_by_config_id():
  361. dn_config = Config.configure_data_node("dn", scope=Scope.SCENARIO)
  362. task_config_1 = Config.configure_task("t1", print, dn_config)
  363. task_config_2 = Config.configure_task("t2", print, dn_config)
  364. task_config_3 = Config.configure_task("t3", print, dn_config)
  365. t_1_1 = _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_1")[0]
  366. t_1_2 = _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_2")[0]
  367. t_1_3 = _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_3")[0]
  368. assert len(_TaskManager._get_all()) == 3
  369. t_2_1 = _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_4")[0]
  370. t_2_2 = _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_5")[0]
  371. assert len(_TaskManager._get_all()) == 5
  372. t_3_1 = _TaskManager._bulk_get_or_create([task_config_3], scenario_id="scenario_6")[0]
  373. assert len(_TaskManager._get_all()) == 6
  374. t1_tasks = _TaskManager._get_by_config_id(task_config_1.id)
  375. assert len(t1_tasks) == 3
  376. assert sorted([t_1_1.id, t_1_2.id, t_1_3.id]) == sorted([task.id for task in t1_tasks])
  377. t2_tasks = _TaskManager._get_by_config_id(task_config_2.id)
  378. assert len(t2_tasks) == 2
  379. assert sorted([t_2_1.id, t_2_2.id]) == sorted([task.id for task in t2_tasks])
  380. t3_tasks = _TaskManager._get_by_config_id(task_config_3.id)
  381. assert len(t3_tasks) == 1
  382. assert sorted([t_3_1.id]) == sorted([task.id for task in t3_tasks])
  383. def test_get_scenarios_by_config_id_in_multiple_versions_environment():
  384. dn_config = Config.configure_data_node("dn", scope=Scope.SCENARIO)
  385. task_config_1 = Config.configure_task("t1", print, dn_config)
  386. task_config_2 = Config.configure_task("t2", print, dn_config)
  387. _VersionManager._set_experiment_version("1.0")
  388. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_1")[0]
  389. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_2")[0]
  390. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_3")[0]
  391. _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_4")[0]
  392. _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_5")[0]
  393. assert len(_TaskManager._get_by_config_id(task_config_1.id)) == 3
  394. assert len(_TaskManager._get_by_config_id(task_config_2.id)) == 2
  395. _VersionManager._set_experiment_version("2.0")
  396. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_1")[0]
  397. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_2")[0]
  398. _TaskManager._bulk_get_or_create([task_config_1], scenario_id="scenario_3")[0]
  399. _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_4")[0]
  400. _TaskManager._bulk_get_or_create([task_config_2], scenario_id="scenario_5")[0]
  401. assert len(_TaskManager._get_by_config_id(task_config_1.id)) == 3
  402. assert len(_TaskManager._get_by_config_id(task_config_2.id)) == 2
  403. def _create_task_from_config(task_config, *args, **kwargs):
  404. return _TaskManager._bulk_get_or_create([task_config], *args, **kwargs)[0]
  405. def test_can_duplicate():
  406. dn_config = Config.configure_pickle_data_node("dn", scope=Scope.SCENARIO)
  407. task_config = Config.configure_task("task_1", print, [dn_config])
  408. task = _TaskManager._bulk_get_or_create([task_config])[0]
  409. reasons = _TaskManager._can_duplicate(task.id)
  410. assert bool(reasons)
  411. assert reasons._reasons == {}
  412. reasons = _TaskManager._can_duplicate(task)
  413. assert bool(reasons)
  414. assert reasons._reasons == {}
  415. reasons = _TaskManager._can_duplicate("1")
  416. assert not bool(reasons)
  417. assert reasons._reasons["1"] == {EntityDoesNotExist("1")}
  418. assert str(list(reasons._reasons["1"])[0]) == "Entity '1' does not exist in the repository"