test_scenario_config.py 12 KB


# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
  11. import os
  12. from unittest import mock
  13. from taipy.common.config import Config
  14. from taipy.common.config.common.frequency import Frequency
  15. from tests.core.utils.named_temporary_file import NamedTemporaryFile
  16. def my_func():
  17. pass
  18. def _configure_scenario_in_toml():
  19. return NamedTemporaryFile(
  20. content="""
  21. [TAIPY]
  22. [TASK.task1]
  23. inputs = []
  24. outputs = []
  25. [TASK.task2]
  26. inputs = []
  27. outputs = []
  28. [SCENARIO.scenarios1]
  29. tasks = [ "task1:SECTION", "task2:SECTION"]
  30. """
  31. )
  32. def _check_tasks_instance(task_id, scenario_id):
  33. """Check if the task instance in the task config correctly points to the Config._applied_config,
  34. not the Config._python_config or the Config._file_config
  35. """
  36. task_config_applied_instance = Config.tasks[task_id]
  37. task_config_instance_via_scenario = None
  38. for task in Config.scenarios[scenario_id].tasks:
  39. if task.id == task_id:
  40. task_config_instance_via_scenario = task
  41. task_config_python_instance = None
  42. if Config._python_config._sections.get("TASK", None):
  43. task_config_python_instance = Config._python_config._sections["TASK"][task_id]
  44. task_config_file_instance = None
  45. if Config._file_config._sections.get("TASK", None):
  46. task_config_file_instance = Config._file_config._sections["TASK"][task_id]
  47. assert task_config_python_instance is not task_config_applied_instance
  48. assert task_config_python_instance is not task_config_instance_via_scenario
  49. assert task_config_file_instance is not task_config_applied_instance
  50. assert task_config_file_instance is not task_config_instance_via_scenario
  51. assert task_config_instance_via_scenario is task_config_applied_instance
  52. def test_task_instance_when_configure_scenario_in_python():
  53. task1_config = Config.configure_task("task1", [])
  54. task2_config = Config.configure_task("task2", print)
  55. Config.configure_scenario("scenarios1", [task1_config, task2_config])
  56. _check_tasks_instance("task1", "scenarios1")
  57. _check_tasks_instance("task2", "scenarios1")
  58. def test_task_instance_when_configure_scenario_by_loading_toml():
  59. toml_config = _configure_scenario_in_toml()
  60. Config.load(toml_config.filename)
  61. _check_tasks_instance("task1", "scenarios1")
  62. _check_tasks_instance("task2", "scenarios1")
  63. def test_task_instance_when_configure_scenario_by_overriding_toml():
  64. toml_config = _configure_scenario_in_toml()
  65. Config.override(toml_config.filename)
  66. _check_tasks_instance("task1", "scenarios1")
  67. _check_tasks_instance("task2", "scenarios1")
  68. def test_scenario_creation():
  69. dn_config_1 = Config.configure_data_node("dn1")
  70. dn_config_2 = Config.configure_data_node("dn2")
  71. dn_config_3 = Config.configure_data_node("dn3")
  72. dn_config_4 = Config.configure_data_node("dn4")
  73. task_config_1 = Config.configure_task("task1", sum, [dn_config_1, dn_config_2], dn_config_3)
  74. task_config_2 = Config.configure_task("task2", print, dn_config_3)
  75. scenario = Config.configure_scenario(
  76. "scenarios1",
  77. [task_config_1, task_config_2],
  78. [dn_config_4],
  79. comparators={"dn_cfg": [my_func]},
  80. sequences={"sequence": []},
  81. )
  82. assert list(Config.scenarios) == ["default", scenario.id]
  83. scenario2 = Config.configure_scenario("scenarios2", [task_config_1], frequency=Frequency.MONTHLY)
  84. assert list(Config.scenarios) == ["default", scenario.id, scenario2.id]
  85. def test_scenario_count():
  86. task_config_1 = Config.configure_task("task1", my_func)
  87. task_config_2 = Config.configure_task("task2", print)
  88. Config.configure_scenario("scenarios1", [task_config_1, task_config_2])
  89. assert len(Config.scenarios) == 2
  90. Config.configure_scenario("scenarios2", [task_config_1])
  91. assert len(Config.scenarios) == 3
  92. Config.configure_scenario("scenarios3", [task_config_2])
  93. assert len(Config.scenarios) == 4
  94. def test_scenario_getitem():
  95. dn_config_1 = Config.configure_data_node("dn1")
  96. dn_config_2 = Config.configure_data_node("dn2")
  97. dn_config_3 = Config.configure_data_node("dn3")
  98. dn_config_4 = Config.configure_data_node("dn4")
  99. task_config_1 = Config.configure_task("task1", sum, [dn_config_1, dn_config_2], dn_config_3)
  100. task_config_2 = Config.configure_task("task2", print, dn_config_3)
  101. scenario_id = "scenarios1"
  102. scenario = Config.configure_scenario(scenario_id, [task_config_1, task_config_2], [dn_config_4])
  103. assert Config.scenarios[scenario_id].id == scenario.id
  104. assert Config.scenarios[scenario_id].task_configs == scenario.task_configs
  105. assert Config.scenarios[scenario_id].tasks == scenario.tasks
  106. assert Config.scenarios[scenario_id].task_configs == scenario.tasks
  107. assert Config.scenarios[scenario_id].additional_data_node_configs == scenario.additional_data_node_configs
  108. assert Config.scenarios[scenario_id].additional_data_nodes == scenario.additional_data_nodes
  109. assert Config.scenarios[scenario_id].additional_data_node_configs == scenario.additional_data_nodes
  110. assert Config.scenarios[scenario_id].data_node_configs == scenario.data_node_configs
  111. assert Config.scenarios[scenario_id].data_nodes == scenario.data_nodes
  112. assert Config.scenarios[scenario_id].data_node_configs == scenario.data_nodes
  113. assert scenario.tasks == [task_config_1, task_config_2]
  114. assert scenario.additional_data_node_configs == [dn_config_4]
  115. assert set(scenario.data_nodes) == {dn_config_4, dn_config_1, dn_config_2, dn_config_3}
  116. assert Config.scenarios[scenario_id].properties == scenario.properties
  117. def test_scenario_creation_no_duplication():
  118. task_config_1 = Config.configure_task("task1", my_func)
  119. task_config_2 = Config.configure_task("task2", print)
  120. dn_config = Config.configure_data_node("dn")
  121. Config.configure_scenario("scenarios1", [task_config_1, task_config_2], [dn_config])
  122. assert len(Config.scenarios) == 2
  123. Config.configure_scenario("scenarios1", [task_config_1, task_config_2], [dn_config])
  124. assert len(Config.scenarios) == 2
  125. def test_scenario_get_set_and_remove_comparators():
  126. task_config_1 = Config.configure_task("task1", my_func)
  127. task_config_2 = Config.configure_task("task2", print)
  128. dn_config_1 = "dn_config_1"
  129. scenario_config_1 = Config.configure_scenario(
  130. "scenarios1", [task_config_1, task_config_2], comparators={dn_config_1: my_func}
  131. )
  132. assert scenario_config_1.comparators is not None
  133. assert scenario_config_1.comparators[dn_config_1] == [my_func]
  134. assert len(scenario_config_1.comparators.keys()) == 1
  135. dn_config_2 = "dn_config_2"
  136. scenario_config_1.add_comparator(dn_config_2, my_func)
  137. assert len(scenario_config_1.comparators.keys()) == 2
  138. scenario_config_1.delete_comparator(dn_config_1)
  139. assert len(scenario_config_1.comparators.keys()) == 1
  140. scenario_config_1.delete_comparator(dn_config_2)
  141. assert len(scenario_config_1.comparators.keys()) == 0
  142. scenario_config_2 = Config.configure_scenario("scenarios2", [task_config_1, task_config_2])
  143. assert scenario_config_2.comparators is not None
  144. scenario_config_2.add_comparator(dn_config_1, my_func)
  145. assert len(scenario_config_2.comparators.keys()) == 1
  146. scenario_config_2.delete_comparator("dn_config_3")
  147. def test_scenario_config_with_env_variable_value():
  148. task_config_1 = Config.configure_task("task1", my_func)
  149. task_config_2 = Config.configure_task("task2", print)
  150. with mock.patch.dict(os.environ, {"FOO": "bar"}):
  151. Config.configure_scenario("scenario_name", [task_config_1, task_config_2], prop="ENV[FOO]")
  152. assert Config.scenarios["scenario_name"].prop == "bar"
  153. assert Config.scenarios["scenario_name"].properties["prop"] == "bar"
  154. assert Config.scenarios["scenario_name"]._properties["prop"] == "ENV[FOO]"
  155. def test_clean_config():
  156. task1_config = Config.configure_task("task1", print, [], [])
  157. task2_config = Config.configure_task("task2", print, [], [])
  158. scenario1_config = Config.configure_scenario(
  159. "id1",
  160. [task1_config, task2_config],
  161. [],
  162. Frequency.YEARLY,
  163. {"foo": "bar"},
  164. prop="foo",
  165. sequences={"sequence_1": []},
  166. )
  167. scenario2_config = Config.configure_scenario(
  168. "id2",
  169. [task2_config, task1_config],
  170. [],
  171. Frequency.MONTHLY,
  172. {"foz": "baz"},
  173. prop="bar",
  174. sequences={"sequence_2": []},
  175. )
  176. assert Config.scenarios["id1"] is scenario1_config
  177. assert Config.scenarios["id2"] is scenario2_config
  178. scenario1_config._clean()
  179. scenario2_config._clean()
  180. # Check if the instance before and after _clean() is the same
  181. assert Config.scenarios["id1"] is scenario1_config
  182. assert Config.scenarios["id2"] is scenario2_config
  183. assert scenario1_config.id == "id1"
  184. assert scenario2_config.id == "id2"
  185. assert scenario1_config.tasks == scenario1_config.task_configs == []
  186. assert scenario1_config.additional_data_nodes == scenario1_config.additional_data_node_configs == []
  187. assert scenario1_config.data_nodes == scenario1_config.data_node_configs == []
  188. assert scenario1_config.sequences == scenario1_config.sequences == {}
  189. assert scenario1_config.frequency is scenario1_config.frequency is None
  190. assert scenario1_config.comparators == scenario1_config.comparators == {}
  191. assert scenario1_config.properties == scenario1_config.properties == {}
  192. assert scenario2_config.tasks == scenario2_config.task_configs == []
  193. assert scenario2_config.additional_data_nodes == scenario2_config.additional_data_node_configs == []
  194. assert scenario2_config.data_nodes == scenario2_config.data_node_configs == []
  195. assert scenario2_config.sequences == scenario1_config.sequences == {}
  196. assert scenario2_config.frequency is scenario2_config.frequency is None
  197. assert scenario2_config.comparators == scenario2_config.comparators == {}
  198. assert scenario2_config.properties == scenario2_config.properties == {}
  199. def test_add_sequence():
  200. task1_config = Config.configure_task("task1", print, [], [])
  201. task2_config = Config.configure_task("task2", print, [], [])
  202. task3_config = Config.configure_task("task3", print, [], [])
  203. task4_config = Config.configure_task("task4", print, [], [])
  204. scenario_config = Config.configure_scenario(
  205. "id", [task1_config, task2_config, task3_config, task4_config], [], Frequency.YEARLY, prop="foo"
  206. )
  207. assert Config.scenarios["id"] is scenario_config
  208. assert scenario_config.id == "id"
  209. assert (
  210. scenario_config.tasks
  211. == scenario_config.task_configs
  212. == [task1_config, task2_config, task3_config, task4_config]
  213. )
  214. assert scenario_config.additional_data_nodes == scenario_config.additional_data_node_configs == []
  215. assert scenario_config.data_nodes == scenario_config.data_node_configs == []
  216. assert scenario_config.frequency is scenario_config.frequency == Frequency.YEARLY
  217. assert scenario_config.comparators == scenario_config.comparators == {}
  218. assert scenario_config.properties == {"prop": "foo"}
  219. scenario_config.add_sequences(
  220. {
  221. "sequence1": [task1_config],
  222. "sequence2": [task2_config, task3_config],
  223. "sequence3": [task1_config, task2_config, task4_config],
  224. }
  225. )
  226. assert len(scenario_config.sequences) == 3
  227. assert scenario_config.sequences["sequence1"] == [task1_config]
  228. assert scenario_config.sequences["sequence2"] == [task2_config, task3_config]
  229. assert scenario_config.sequences["sequence3"] == [task1_config, task2_config, task4_config]
  230. scenario_config.remove_sequences("sequence1")
  231. assert len(scenario_config.sequences) == 2
  232. scenario_config.remove_sequences(["sequence2", "sequence3"])
  233. assert len(scenario_config.sequences) == 0