test_task_repositories.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
  11. import os
  12. import pytest
  13. from taipy.config.config import Config
  14. from taipy.core.data._data_fs_repository import _DataFSRepository
  15. from taipy.core.data._data_sql_repository import _DataSQLRepository
  16. from taipy.core.exceptions import ModelNotFound
  17. from taipy.core.task._task_fs_repository import _TaskFSRepository
  18. from taipy.core.task._task_sql_repository import _TaskSQLRepository
  19. from taipy.core.task.task import Task, TaskId
  20. class TestTaskFSRepository:
  21. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  22. def test_save_and_load(self, data_node, repo, tmp_sqlite):
  23. if repo[1] == _DataSQLRepository:
  24. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  25. task_repository, data_repository = repo[0](), repo[1]()
  26. data_repository._save(data_node)
  27. task = Task("task_config_id", {}, print, [data_node], [data_node])
  28. task_repository._save(task)
  29. loaded_task = task_repository._load(task.id)
  30. assert isinstance(loaded_task, Task)
  31. assert task._config_id == loaded_task._config_id
  32. assert task.id == loaded_task.id
  33. assert task._owner_id == loaded_task._owner_id
  34. assert task._parent_ids == loaded_task._parent_ids
  35. assert task._input == loaded_task._input
  36. assert task._output == loaded_task._output
  37. assert task._function == loaded_task._function
  38. assert task._version == loaded_task._version
  39. assert task._skippable == loaded_task._skippable
  40. assert task._properties == loaded_task._properties
  41. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  42. def test_exists(self, data_node, repo, tmp_sqlite):
  43. if repo[1] == _DataSQLRepository:
  44. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  45. task_repository, data_repository = repo[0](), repo[1]()
  46. data_repository._save(data_node)
  47. task = Task("task_config_id", {}, print, [data_node], [data_node])
  48. task_repository._save(task)
  49. assert task_repository._exists(task.id)
  50. assert not task_repository._exists("not-existed-task")
  51. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  52. def test_load_all(self, data_node, repo, tmp_sqlite):
  53. if repo[1] == _DataSQLRepository:
  54. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  55. task_repository, data_repository = repo[0](), repo[1]()
  56. data_repository._save(data_node)
  57. task = Task("task_config_id", {}, print, [data_node], [data_node])
  58. for i in range(10):
  59. task.id = TaskId(f"task-{i}")
  60. task_repository._save(task)
  61. data_nodes = task_repository._load_all()
  62. assert len(data_nodes) == 10
  63. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  64. def test_load_all_with_filters(self, data_node, repo, tmp_sqlite):
  65. if repo[1] == _DataSQLRepository:
  66. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  67. task_repository, data_repository = repo[0](), repo[1]()
  68. data_repository._save(data_node)
  69. task = Task("task_config_id", {}, print, [data_node], [data_node])
  70. for i in range(10):
  71. task.id = TaskId(f"task-{i}")
  72. task._owner_id = f"owner-{i}"
  73. task_repository._save(task)
  74. objs = task_repository._load_all(filters=[{"owner_id": "owner-2"}])
  75. assert len(objs) == 1
  76. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  77. def test_delete(self, data_node, repo, tmp_sqlite):
  78. if repo[1] == _DataSQLRepository:
  79. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  80. task_repository, data_repository = repo[0](), repo[1]()
  81. data_repository._save(data_node)
  82. task = Task("task_config_id", {}, print, [data_node], [data_node])
  83. task_repository._save(task)
  84. task_repository._delete(task.id)
  85. with pytest.raises(ModelNotFound):
  86. task_repository._load(task.id)
  87. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  88. def test_delete_all(self, data_node, repo, tmp_sqlite):
  89. if repo[1] == _DataSQLRepository:
  90. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  91. task_repository, data_repository = repo[0](), repo[1]()
  92. data_repository._save(data_node)
  93. task = Task("task_config_id", {}, print, [data_node], [data_node])
  94. for i in range(10):
  95. task.id = TaskId(f"task-{i}")
  96. task_repository._save(task)
  97. assert len(task_repository._load_all()) == 10
  98. task_repository._delete_all()
  99. assert len(task_repository._load_all()) == 0
  100. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  101. def test_delete_many(self, data_node, repo, tmp_sqlite):
  102. if repo[1] == _DataSQLRepository:
  103. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  104. task_repository, data_repository = repo[0](), repo[1]()
  105. data_repository._save(data_node)
  106. task = Task("task_config_id", {}, print, [data_node], [data_node])
  107. for i in range(10):
  108. task.id = TaskId(f"task-{i}")
  109. task_repository._save(task)
  110. objs = task_repository._load_all()
  111. assert len(objs) == 10
  112. ids = [x.id for x in objs[:3]]
  113. task_repository._delete_many(ids)
  114. assert len(task_repository._load_all()) == 7
  115. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  116. def test_delete_by(self, data_node, repo, tmp_sqlite):
  117. if repo[1] == _DataSQLRepository:
  118. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  119. task_repository, data_repository = repo[0](), repo[1]()
  120. data_repository._save(data_node)
  121. task = Task("task_config_id", {}, print, [data_node], [data_node])
  122. # Create 5 entities with version 1.0 and 5 entities with version 2.0
  123. for i in range(10):
  124. task.id = TaskId(f"task-{i}")
  125. task._version = f"{(i+1) // 5}.0"
  126. task_repository._save(task)
  127. objs = task_repository._load_all()
  128. assert len(objs) == 10
  129. task_repository._delete_by("version", "1.0")
  130. assert len(task_repository._load_all()) == 5
  131. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  132. def test_search(self, data_node, repo, tmp_sqlite):
  133. if repo[1] == _DataSQLRepository:
  134. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  135. task_repository, data_repository = repo[0](), repo[1]()
  136. data_repository._save(data_node)
  137. task = Task("task_config_id", {}, print, [data_node], [data_node], version="random_version_number")
  138. for i in range(10):
  139. task.id = TaskId(f"task-{i}")
  140. task._owner_id = f"owner-{i}"
  141. task_repository._save(task)
  142. assert len(task_repository._load_all()) == 10
  143. objs = task_repository._search("owner_id", "owner-2")
  144. assert len(objs) == 1
  145. assert isinstance(objs[0], Task)
  146. objs = task_repository._search("owner_id", "owner-2", filters=[{"version": "random_version_number"}])
  147. assert len(objs) == 1
  148. assert isinstance(objs[0], Task)
  149. assert task_repository._search("owner_id", "owner-2", filters=[{"version": "non_existed_version"}]) == []
  150. @pytest.mark.parametrize("repo", [(_TaskFSRepository, _DataFSRepository), (_TaskSQLRepository, _DataSQLRepository)])
  151. def test_export(self, tmpdir, data_node, repo, tmp_sqlite):
  152. if repo[1] == _DataSQLRepository:
  153. Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
  154. task_repository, data_repository = repo[0](), repo[1]()
  155. data_repository._save(data_node)
  156. task = Task("task_config_id", {}, print, [data_node], [data_node])
  157. task_repository._save(task)
  158. task_repository._export(task.id, tmpdir.strpath)
  159. dir_path = task_repository.dir_path if repo[0] == _TaskFSRepository else os.path.join(tmpdir.strpath, "task")
  160. assert os.path.exists(os.path.join(dir_path, f"{task.id}.json"))