test_task_config.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import os
  12. from unittest import mock
  13. from taipy.common.config import Config
  14. from taipy.core.config import DataNodeConfig
  15. from tests.core.utils.named_temporary_file import NamedTemporaryFile
  16. def _configure_task_in_toml():
  17. return NamedTemporaryFile(
  18. content="""
  19. [TAIPY]
  20. [DATA_NODE.input]
  21. [DATA_NODE.output]
  22. [TASK.tasks1]
  23. function = "builtins.print:function"
  24. inputs = [ "input:SECTION",]
  25. outputs = [ "output:SECTION",]
  26. """
  27. )
  28. def _check_data_nodes_instance(dn_id, task_id):
  29. """Check if the data node instance in the task config correctly points to the Config._applied_config,
  30. not the Config._python_config or the Config._file_config
  31. """
  32. dn_config_applied_instance = Config.data_nodes[dn_id]
  33. for dn in Config.tasks[task_id].inputs:
  34. if dn.id == dn_id:
  35. dn_config_instance_via_task = dn
  36. for dn in Config.tasks[task_id].outputs:
  37. if dn.id == dn_id:
  38. dn_config_instance_via_task = dn
  39. dn_config_python_instance = None
  40. if Config._python_config._sections.get("DATA_NODE", None):
  41. dn_config_python_instance = Config._python_config._sections["DATA_NODE"][dn_id]
  42. dn_config_file_instance = None
  43. if Config._file_config._sections.get("DATA_NODE", None):
  44. dn_config_file_instance = Config._file_config._sections["DATA_NODE"][dn_id]
  45. if dn_config_python_instance:
  46. assert dn_config_python_instance.scope is None
  47. assert dn_config_python_instance is not dn_config_applied_instance
  48. assert dn_config_python_instance is not dn_config_instance_via_task
  49. if dn_config_file_instance:
  50. assert dn_config_file_instance.scope is None
  51. assert dn_config_file_instance is not dn_config_applied_instance
  52. assert dn_config_file_instance is not dn_config_instance_via_task
  53. assert dn_config_applied_instance.scope == DataNodeConfig._DEFAULT_SCOPE
  54. assert dn_config_instance_via_task is dn_config_applied_instance
  55. def test_data_node_instance_when_configure_task_in_python():
  56. input_config = Config.configure_data_node("input")
  57. output_config = Config.configure_data_node("output")
  58. Config.configure_task("tasks1", print, input_config, output_config)
  59. _check_data_nodes_instance("input", "tasks1")
  60. _check_data_nodes_instance("output", "tasks1")
  61. def test_data_node_instance_when_configure_task_by_loading_toml():
  62. toml_config = _configure_task_in_toml()
  63. Config.load(toml_config.filename)
  64. _check_data_nodes_instance("input", "tasks1")
  65. _check_data_nodes_instance("output", "tasks1")
  66. def test_data_node_instance_when_configure_task_by_overriding_toml():
  67. toml_config = _configure_task_in_toml()
  68. Config.override(toml_config.filename)
  69. _check_data_nodes_instance("input", "tasks1")
  70. _check_data_nodes_instance("output", "tasks1")
  71. def test_task_config_creation():
  72. input_config = Config.configure_data_node("input")
  73. output_config = Config.configure_data_node("output")
  74. task_config = Config.configure_task("tasks1", print, input_config, output_config)
  75. assert not task_config.skippable
  76. assert list(Config.tasks) == ["default", task_config.id]
  77. task2 = Config.configure_task("tasks2", print, input_config, output_config, skippable=True)
  78. assert task2.skippable
  79. assert list(Config.tasks) == ["default", task_config.id, task2.id]
  80. def test_task_count():
  81. input_config = Config.configure_data_node("input")
  82. output_config = Config.configure_data_node("output")
  83. Config.configure_task("tasks1", print, input_config, output_config)
  84. assert len(Config.tasks) == 2
  85. Config.configure_task("tasks2", print, input_config, output_config)
  86. assert len(Config.tasks) == 3
  87. Config.configure_task("tasks3", print, input_config, output_config)
  88. assert len(Config.tasks) == 4
  89. def test_task_getitem():
  90. input_config = Config.configure_data_node("input")
  91. output_config = Config.configure_data_node("output")
  92. task_id = "tasks1"
  93. task_cfg = Config.configure_task(task_id, print, input_config, output_config)
  94. assert Config.tasks[task_id].id == task_cfg.id
  95. assert Config.tasks[task_id].properties == task_cfg.properties
  96. assert Config.tasks[task_id].function == task_cfg.function
  97. assert Config.tasks[task_id].input_configs == task_cfg.input_configs
  98. assert Config.tasks[task_id].output_configs == task_cfg.output_configs
  99. assert Config.tasks[task_id].skippable == task_cfg.skippable
  100. def test_task_creation_no_duplication():
  101. input_config = Config.configure_data_node("input")
  102. output_config = Config.configure_data_node("output")
  103. Config.configure_task("tasks1", print, input_config, output_config)
  104. assert len(Config.tasks) == 2
  105. Config.configure_task("tasks1", print, input_config, output_config)
  106. assert len(Config.tasks) == 2
  107. def test_task_config_with_env_variable_value():
  108. input_config = Config.configure_data_node("input")
  109. output_config = Config.configure_data_node("output")
  110. with mock.patch.dict(os.environ, {"FOO": "plop", "BAR": "baz"}):
  111. Config.configure_task("task_name", print, input_config, output_config, prop="ENV[BAR]")
  112. assert Config.tasks["task_name"].prop == "baz"
  113. assert Config.tasks["task_name"].properties["prop"] == "baz"
  114. assert Config.tasks["task_name"]._properties["prop"] == "ENV[BAR]"
  115. def test_clean_config():
  116. dn1 = Config.configure_data_node("dn1")
  117. dn2 = Config.configure_data_node("dn2")
  118. task1_config = Config.configure_task("id1", print, dn1, dn2)
  119. task2_config = Config.configure_task("id2", print, dn2, dn1)
  120. assert Config.tasks["id1"] is task1_config
  121. assert Config.tasks["id2"] is task2_config
  122. task1_config._clean()
  123. task2_config._clean()
  124. # Check if the instance before and after _clean() is the same
  125. assert Config.tasks["id1"] is task1_config
  126. assert Config.tasks["id2"] is task2_config
  127. assert task1_config.id == "id1"
  128. assert task2_config.id == "id2"
  129. assert task1_config.function is task1_config.function is None
  130. assert task1_config.inputs == task1_config.inputs == []
  131. assert task1_config.input_configs == task1_config.input_configs == []
  132. assert task1_config.outputs == task1_config.outputs == []
  133. assert task1_config.output_configs == task1_config.output_configs == []
  134. assert task1_config.skippable is task1_config.skippable is False
  135. assert task1_config.properties == task1_config.properties == {}