# test_scenario_config.py (18 KB)


  1. # Copyright 2021-2025 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import os
  12. from unittest import mock
  13. import pytest
  14. from taipy.common.config import Config
  15. from taipy.common.config.common.frequency import Frequency
  16. from tests.core.utils.named_temporary_file import NamedTemporaryFile
  17. def my_func():
  18. pass
  19. def _configure_scenario_in_toml():
  20. return NamedTemporaryFile(
  21. content="""
  22. [TAIPY]
  23. [TASK.task1]
  24. inputs = []
  25. outputs = []
  26. [TASK.task2]
  27. inputs = []
  28. outputs = []
  29. [SCENARIO.scenarios1]
  30. tasks = [ "task1:SECTION", "task2:SECTION"]
  31. """
  32. )
  33. def _check_tasks_instance(task_id, scenario_id):
  34. """Check if the task instance in the task config correctly points to the Config._applied_config,
  35. not the Config._python_config or the Config._file_config
  36. """
  37. task_config_applied_instance = Config.tasks[task_id]
  38. task_config_instance_via_scenario = None
  39. for task in Config.scenarios[scenario_id].tasks:
  40. if task.id == task_id:
  41. task_config_instance_via_scenario = task
  42. task_config_python_instance = None
  43. if Config._python_config._sections.get("TASK", None):
  44. task_config_python_instance = Config._python_config._sections["TASK"][task_id]
  45. task_config_file_instance = None
  46. if Config._file_config._sections.get("TASK", None):
  47. task_config_file_instance = Config._file_config._sections["TASK"][task_id]
  48. assert task_config_python_instance is not task_config_applied_instance
  49. assert task_config_python_instance is not task_config_instance_via_scenario
  50. assert task_config_file_instance is not task_config_applied_instance
  51. assert task_config_file_instance is not task_config_instance_via_scenario
  52. assert task_config_instance_via_scenario is task_config_applied_instance
  53. def test_task_instance_when_configure_scenario_in_python():
  54. task1_config = Config.configure_task("task1", [])
  55. task2_config = Config.configure_task("task2", print)
  56. Config.configure_scenario("scenarios1", [task1_config, task2_config])
  57. _check_tasks_instance("task1", "scenarios1")
  58. _check_tasks_instance("task2", "scenarios1")
  59. def test_task_instance_when_configure_scenario_by_loading_toml():
  60. toml_config = _configure_scenario_in_toml()
  61. Config.load(toml_config.filename)
  62. _check_tasks_instance("task1", "scenarios1")
  63. _check_tasks_instance("task2", "scenarios1")
  64. def test_task_instance_when_configure_scenario_by_overriding_toml():
  65. toml_config = _configure_scenario_in_toml()
  66. Config.override(toml_config.filename)
  67. _check_tasks_instance("task1", "scenarios1")
  68. _check_tasks_instance("task2", "scenarios1")
  69. def test_scenario_creation():
  70. dn_config_1 = Config.configure_data_node("dn1")
  71. dn_config_2 = Config.configure_data_node("dn2")
  72. dn_config_3 = Config.configure_data_node("dn3")
  73. dn_config_4 = Config.configure_data_node("dn4")
  74. task_config_1 = Config.configure_task("task1", sum, [dn_config_1, dn_config_2], dn_config_3)
  75. task_config_2 = Config.configure_task("task2", print, dn_config_3)
  76. scenario_cfg = Config.configure_scenario(
  77. "scenarios1",
  78. [task_config_1, task_config_2],
  79. [dn_config_4],
  80. comparators={"dn_cfg": [my_func]},
  81. sequences={"sequence": []},
  82. )
  83. assert list(Config.scenarios.keys()) == ["default", scenario_cfg.id]
  84. scenario2 = Config.configure_scenario("scenarios2", [task_config_1], frequency=Frequency.MONTHLY)
  85. assert list(Config.scenarios.keys()) == ["default", scenario_cfg.id, scenario2.id]
  86. def test_datanode_config_ranks():
  87. dn_config_1 = Config.configure_data_node("dn1")
  88. dn_config_2 = Config.configure_data_node("dn2")
  89. dn_config_3 = Config.configure_data_node("dn3")
  90. dn_config_4 = Config.configure_data_node("dn4")
  91. dn_config_5 = Config.configure_data_node("dn5")
  92. dn_config_6 = Config.configure_data_node("dn6")
  93. task_config_1 = Config.configure_task("task1", sum, dn_config_1, dn_config_2)
  94. task_config_2 = Config.configure_task("task2", sum, dn_config_2, dn_config_3)
  95. task_config_3 = Config.configure_task("task3", sum, [dn_config_1, dn_config_2], dn_config_3)
  96. task_config_4 = Config.configure_task("task4", sum, dn_config_3, [dn_config_4, dn_config_5])
  97. task_config_5 = Config.configure_task("task5", sum, dn_config_5, dn_config_6)
  98. # s1 additional: dn3
  99. # s1 dag: dn1 -> dn2
  100. Config.configure_scenario("s1", [task_config_1],[dn_config_3])
  101. # s2 additional: dn4
  102. # s2 dag: dn2 -> dn3
  103. Config.configure_scenario("s2", [task_config_2],[dn_config_4])
  104. # s3 additional: None
  105. # s3 dag: dn1 -> dn2 -> dn3
  106. Config.configure_scenario("s3", [task_config_1, task_config_2])
  107. # s4 additional: None
  108. # s4 dag: dn1 -- --> dn4
  109. # \ /
  110. # |----> dn3 ---|
  111. # / \
  112. # dn2 -- --> dn5 ---> dn6
  113. Config.configure_scenario("s4", [task_config_3, task_config_4, task_config_5])
  114. assert len(dn_config_1._ranks) == 3
  115. assert dn_config_1._ranks["s1"] == 1
  116. assert dn_config_1._ranks["s3"] == 1
  117. assert dn_config_1._ranks["s4"] == 1
  118. assert len(dn_config_2._ranks) == 4
  119. assert dn_config_2._ranks["s1"] == 2
  120. assert dn_config_2._ranks["s2"] == 1
  121. assert dn_config_2._ranks["s3"] == 2
  122. assert dn_config_2._ranks["s4"] == 1
  123. assert len(dn_config_3._ranks) == 4
  124. assert dn_config_3._ranks["s1"] == 0
  125. assert dn_config_3._ranks["s2"] == 2
  126. assert dn_config_3._ranks["s3"] == 3
  127. assert dn_config_3._ranks["s4"] == 2
  128. assert len(dn_config_4._ranks) == 2
  129. assert dn_config_4._ranks["s2"] == 0
  130. assert dn_config_4._ranks["s4"] == 3
  131. assert len(dn_config_5._ranks) == 1
  132. assert dn_config_5._ranks["s4"] == 3
  133. assert len(dn_config_6._ranks) == 1
  134. assert dn_config_6._ranks["s4"] == 4
  135. def test_scenario_count():
  136. task_config_1 = Config.configure_task("task1", my_func)
  137. task_config_2 = Config.configure_task("task2", print)
  138. Config.configure_scenario("scenarios1", [task_config_1, task_config_2])
  139. assert len(Config.scenarios) == 2
  140. Config.configure_scenario("scenarios2", [task_config_1])
  141. assert len(Config.scenarios) == 3
  142. Config.configure_scenario("scenarios3", [task_config_2])
  143. assert len(Config.scenarios) == 4
  144. def test_scenario_getitem():
  145. dn_config_1 = Config.configure_data_node("dn1")
  146. dn_config_2 = Config.configure_data_node("dn2")
  147. dn_config_3 = Config.configure_data_node("dn3")
  148. dn_config_4 = Config.configure_data_node("dn4")
  149. task_config_1 = Config.configure_task("task1", sum, [dn_config_1, dn_config_2], dn_config_3)
  150. task_config_2 = Config.configure_task("task2", print, dn_config_3)
  151. scenario_id = "scenarios1"
  152. scenario = Config.configure_scenario(scenario_id, [task_config_1, task_config_2], [dn_config_4])
  153. assert Config.scenarios[scenario_id].id == scenario.id
  154. assert Config.scenarios[scenario_id].task_configs == scenario.task_configs
  155. assert Config.scenarios[scenario_id].tasks == scenario.tasks
  156. assert Config.scenarios[scenario_id].task_configs == scenario.tasks
  157. assert Config.scenarios[scenario_id].additional_data_node_configs == scenario.additional_data_node_configs
  158. assert Config.scenarios[scenario_id].additional_data_nodes == scenario.additional_data_nodes
  159. assert Config.scenarios[scenario_id].additional_data_node_configs == scenario.additional_data_nodes
  160. assert Config.scenarios[scenario_id].data_node_configs == scenario.data_node_configs
  161. assert Config.scenarios[scenario_id].data_nodes == scenario.data_nodes
  162. assert Config.scenarios[scenario_id].data_node_configs == scenario.data_nodes
  163. assert scenario.tasks == [task_config_1, task_config_2]
  164. assert scenario.additional_data_node_configs == [dn_config_4]
  165. assert set(scenario.data_nodes) == {dn_config_4, dn_config_1, dn_config_2, dn_config_3}
  166. assert Config.scenarios[scenario_id].properties == scenario.properties
  167. def test_scenario_creation_no_duplication():
  168. task_config_1 = Config.configure_task("task1", my_func)
  169. task_config_2 = Config.configure_task("task2", print)
  170. dn_config = Config.configure_data_node("dn")
  171. Config.configure_scenario("scenarios1", [task_config_1, task_config_2], [dn_config])
  172. assert len(Config.scenarios) == 2
  173. Config.configure_scenario("scenarios1", [task_config_1, task_config_2], [dn_config])
  174. assert len(Config.scenarios) == 2
  175. def test_scenario_get_set_and_remove_comparators():
  176. task_config_1 = Config.configure_task("task1", my_func)
  177. task_config_2 = Config.configure_task("task2", print)
  178. dn_config_1 = "dn_config_1"
  179. scenario_config_1 = Config.configure_scenario(
  180. "scenarios1", [task_config_1, task_config_2], comparators={dn_config_1: my_func}
  181. )
  182. assert scenario_config_1.comparators is not None
  183. assert scenario_config_1.comparators[dn_config_1] == [my_func]
  184. assert len(scenario_config_1.comparators.keys()) == 1
  185. dn_config_2 = "dn_config_2"
  186. scenario_config_1.add_comparator(dn_config_2, my_func)
  187. assert len(scenario_config_1.comparators.keys()) == 2
  188. scenario_config_1.delete_comparator(dn_config_1)
  189. assert len(scenario_config_1.comparators.keys()) == 1
  190. scenario_config_1.delete_comparator(dn_config_2)
  191. assert len(scenario_config_1.comparators.keys()) == 0
  192. scenario_config_2 = Config.configure_scenario("scenarios2", [task_config_1, task_config_2])
  193. assert scenario_config_2.comparators is not None
  194. scenario_config_2.add_comparator(dn_config_1, my_func)
  195. assert len(scenario_config_2.comparators.keys()) == 1
  196. scenario_config_2.delete_comparator("dn_config_3")
  197. def test_scenario_config_with_env_variable_value():
  198. task_config_1 = Config.configure_task("task1", my_func)
  199. task_config_2 = Config.configure_task("task2", print)
  200. with mock.patch.dict(os.environ, {"FOO": "bar"}):
  201. Config.configure_scenario("scenario_name", [task_config_1, task_config_2], prop="ENV[FOO]")
  202. assert Config.scenarios["scenario_name"].prop == "bar"
  203. assert Config.scenarios["scenario_name"].properties["prop"] == "bar"
  204. assert Config.scenarios["scenario_name"]._properties["prop"] == "ENV[FOO]"
  205. def test_clean_config():
  206. task1_config = Config.configure_task("task1", print, [], [])
  207. task2_config = Config.configure_task("task2", print, [], [])
  208. scenario1_config = Config.configure_scenario(
  209. "id1",
  210. [task1_config, task2_config],
  211. [],
  212. Frequency.YEARLY,
  213. {"foo": "bar"},
  214. prop="foo",
  215. sequences={"sequence_1": []},
  216. )
  217. scenario2_config = Config.configure_scenario(
  218. "id2",
  219. [task2_config, task1_config],
  220. [],
  221. Frequency.MONTHLY,
  222. {"foz": "baz"},
  223. prop="bar",
  224. sequences={"sequence_2": []},
  225. )
  226. assert Config.scenarios["id1"] is scenario1_config
  227. assert Config.scenarios["id2"] is scenario2_config
  228. scenario1_config._clean()
  229. scenario2_config._clean()
  230. # Check if the instance before and after _clean() is the same
  231. assert Config.scenarios["id1"] is scenario1_config
  232. assert Config.scenarios["id2"] is scenario2_config
  233. assert scenario1_config.id == "id1"
  234. assert scenario2_config.id == "id2"
  235. assert scenario1_config.tasks == scenario1_config.task_configs == []
  236. assert scenario1_config.additional_data_nodes == scenario1_config.additional_data_node_configs == []
  237. assert scenario1_config.data_nodes == scenario1_config.data_node_configs == []
  238. assert scenario1_config.sequences == scenario1_config.sequences == {}
  239. assert scenario1_config.frequency is scenario1_config.frequency is None
  240. assert scenario1_config.comparators == scenario1_config.comparators == {}
  241. assert scenario1_config.properties == scenario1_config.properties == {}
  242. assert scenario2_config.tasks == scenario2_config.task_configs == []
  243. assert scenario2_config.additional_data_nodes == scenario2_config.additional_data_node_configs == []
  244. assert scenario2_config.data_nodes == scenario2_config.data_node_configs == []
  245. assert scenario2_config.sequences == scenario1_config.sequences == {}
  246. assert scenario2_config.frequency is scenario2_config.frequency is None
  247. assert scenario2_config.comparators == scenario2_config.comparators == {}
  248. assert scenario2_config.properties == scenario2_config.properties == {}
  249. def test_add_sequence():
  250. task1_config = Config.configure_task("task1", print, [], [])
  251. task2_config = Config.configure_task("task2", print, [], [])
  252. task3_config = Config.configure_task("task3", print, [], [])
  253. task4_config = Config.configure_task("task4", print, [], [])
  254. scenario_config = Config.configure_scenario(
  255. "id", [task1_config, task2_config, task3_config, task4_config], [], Frequency.YEARLY, prop="foo"
  256. )
  257. assert Config.scenarios["id"] is scenario_config
  258. assert scenario_config.id == "id"
  259. assert (
  260. scenario_config.tasks
  261. == scenario_config.task_configs
  262. == [task1_config, task2_config, task3_config, task4_config]
  263. )
  264. assert scenario_config.additional_data_nodes == scenario_config.additional_data_node_configs == []
  265. assert scenario_config.data_nodes == scenario_config.data_node_configs == []
  266. assert scenario_config.frequency is scenario_config.frequency == Frequency.YEARLY
  267. assert scenario_config.comparators == scenario_config.comparators == {}
  268. assert scenario_config.properties == {"prop": "foo"}
  269. scenario_config.add_sequences(
  270. {
  271. "sequence1": [task1_config],
  272. "sequence2": [task2_config, task3_config],
  273. "sequence3": [task1_config, task2_config, task4_config],
  274. }
  275. )
  276. assert len(scenario_config.sequences) == 3
  277. assert scenario_config.sequences["sequence1"] == [task1_config]
  278. assert scenario_config.sequences["sequence2"] == [task2_config, task3_config]
  279. assert scenario_config.sequences["sequence3"] == [task1_config, task2_config, task4_config]
  280. scenario_config.remove_sequences("sequence1")
  281. assert len(scenario_config.sequences) == 2
  282. scenario_config.remove_sequences(["sequence2", "sequence3"])
  283. assert len(scenario_config.sequences) == 0
  284. @pytest.mark.skip(reason="Generates a png that must be visually verified.")
  285. def test_draw_1():
  286. dn_config_1 = Config.configure_data_node("dn1")
  287. dn_config_2 = Config.configure_data_node("dn2")
  288. dn_config_3 = Config.configure_data_node("dn3")
  289. dn_config_4 = Config.configure_data_node("dn4")
  290. dn_config_5 = Config.configure_data_node("dn5")
  291. task_config_1 = Config.configure_task("task1", sum, input=[dn_config_1, dn_config_2], output=dn_config_3)
  292. task_config_2 = Config.configure_task("task2", sum, input=[dn_config_1, dn_config_3], output=dn_config_4)
  293. task_config_3 = Config.configure_task("task3", print, input=dn_config_4)
  294. scenario_cfg = Config.configure_scenario(
  295. "scenario1",
  296. [task_config_1, task_config_2, task_config_3],
  297. [dn_config_5],
  298. )
  299. scenario_cfg.draw()
  300. @pytest.mark.skip(reason="Generates a png that must be visually verified.")
  301. def test_draw_2():
  302. data_node_1 = Config.configure_data_node("s1")
  303. data_node_2 = Config.configure_data_node("s2")
  304. data_node_4 = Config.configure_data_node("s4")
  305. data_node_5 = Config.configure_data_node("s5")
  306. data_node_6 = Config.configure_data_node("s6")
  307. data_node_7 = Config.configure_data_node("s7")
  308. task_1 = Config.configure_task("t1", print, [data_node_1, data_node_2], [data_node_4])
  309. task_2 = Config.configure_task("t2", print, None, [data_node_5])
  310. task_3 = Config.configure_task("t3", print, [data_node_5, data_node_4], [data_node_6])
  311. task_4 = Config.configure_task("t4", print, [data_node_4], [data_node_7])
  312. scenario_cfg = Config.configure_scenario("scenario1", [task_4, task_2, task_1, task_3])
  313. # 6 | t2 _____
  314. # 5 | \
  315. # 4 | s5 _________________ t3 _______ s6
  316. # 3 | s1 __ _ s4 _____/
  317. # 2 | \ _ t1 ____/ \_ t4 _______ s7
  318. # 1 | /
  319. # 0 | s2 --
  320. # |________________________________________________
  321. # 0 1 2 3 4
  322. scenario_cfg.draw("draw_2")
  323. @pytest.mark.skip(reason="Generates a png that must be visually verified.")
  324. def test_draw_3():
  325. data_node_1 = Config.configure_data_node("s1")
  326. data_node_2 = Config.configure_data_node("s2")
  327. data_node_3 = Config.configure_data_node("s3")
  328. data_node_4 = Config.configure_data_node("s4")
  329. data_node_5 = Config.configure_data_node("s5")
  330. data_node_6 = Config.configure_data_node("s6")
  331. data_node_7 = Config.configure_data_node("s7")
  332. task_1 = Config.configure_task("t1", print, [data_node_1, data_node_2, data_node_3], [data_node_4])
  333. task_2 = Config.configure_task("t2", print, [data_node_4], None)
  334. task_3 = Config.configure_task("t3", print, [data_node_4], [data_node_5])
  335. task_4 = Config.configure_task("t4", print, None, output=[data_node_6])
  336. task_5 = Config.configure_task("t5", print, [data_node_7], None)
  337. scenario_cfg = Config.configure_scenario("scenario1", [task_5, task_3, task_4, task_2, task_1])
  338. # 12 | s7 __
  339. # 11 | \
  340. # 10 | \
  341. # 9 | t4 _ \_ t5
  342. # 8 | \ ____ t3 ___
  343. # 7 | \ / \
  344. # 6 | s3 _ \__ s6 _ s4 _/ \___ s5
  345. # 5 | \ / \
  346. # 4 | \ / \____ t2
  347. # 3 | s2 ___\__ t1 __/
  348. # 2 | /
  349. # 1 | /
  350. # 0 | s1 _/
  351. # |________________________________________________
  352. # 0 1 2 3 4
  353. scenario_cfg.draw("draw_3")