test_task_repositories.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
# Copyright 2021-2025 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
  11. import os
  12. import pytest
  13. from taipy.core.data._data_fs_repository import _DataFSRepository
  14. from taipy.core.exceptions import ModelNotFound
  15. from taipy.core.task._task_fs_repository import _TaskFSRepository
  16. from taipy.core.task.task import Task, TaskId
  17. class TestTaskFSRepository:
  18. def test_save_and_load(self, data_node):
  19. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  20. data_repository._save(data_node)
  21. task = Task("task_config_id", {}, print, [data_node], [data_node])
  22. task_repository._save(task)
  23. loaded_task = task_repository._load(task.id)
  24. assert isinstance(loaded_task, Task)
  25. assert task._config_id == loaded_task._config_id
  26. assert task.id == loaded_task.id
  27. assert task._owner_id == loaded_task._owner_id
  28. assert task._parent_ids == loaded_task._parent_ids
  29. assert task._input == loaded_task._input
  30. assert task._output == loaded_task._output
  31. assert task._function == loaded_task._function
  32. assert task._version == loaded_task._version
  33. assert task._skippable == loaded_task._skippable
  34. assert task._properties == loaded_task._properties
  35. def test_exists(self, data_node):
  36. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  37. data_repository._save(data_node)
  38. task = Task("task_config_id", {}, print, [data_node], [data_node])
  39. task_repository._save(task)
  40. assert task_repository._exists(task.id)
  41. assert not task_repository._exists("not-existed-task")
  42. def test_load_all(self, data_node):
  43. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  44. data_repository._save(data_node)
  45. task = Task("task_config_id", {}, print, [data_node], [data_node])
  46. for i in range(10):
  47. task.id = TaskId(f"task-{i}")
  48. task_repository._save(task)
  49. data_nodes = task_repository._load_all()
  50. assert len(data_nodes) == 10
  51. def test_load_all_with_filters(self, data_node):
  52. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  53. data_repository._save(data_node)
  54. task = Task("task_config_id", {}, print, [data_node], [data_node])
  55. for i in range(10):
  56. task.id = TaskId(f"task-{i}")
  57. task._owner_id = f"owner-{i}"
  58. task_repository._save(task)
  59. objs = task_repository._load_all(filters=[{"owner_id": "owner-2"}])
  60. assert len(objs) == 1
  61. def test_delete(self, data_node):
  62. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  63. data_repository._save(data_node)
  64. task = Task("task_config_id", {}, print, [data_node], [data_node])
  65. task_repository._save(task)
  66. task_repository._delete(task.id)
  67. with pytest.raises(ModelNotFound):
  68. task_repository._load(task.id)
  69. def test_delete_all(self, data_node):
  70. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  71. data_repository._save(data_node)
  72. task = Task("task_config_id", {}, print, [data_node], [data_node])
  73. for i in range(10):
  74. task.id = TaskId(f"task-{i}")
  75. task_repository._save(task)
  76. assert len(task_repository._load_all()) == 10
  77. task_repository._delete_all()
  78. assert len(task_repository._load_all()) == 0
  79. def test_delete_many(self, data_node):
  80. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  81. data_repository._save(data_node)
  82. task = Task("task_config_id", {}, print, [data_node], [data_node])
  83. for i in range(10):
  84. task.id = TaskId(f"task-{i}")
  85. task_repository._save(task)
  86. objs = task_repository._load_all()
  87. assert len(objs) == 10
  88. ids = [x.id for x in objs[:3]]
  89. task_repository._delete_many(ids)
  90. assert len(task_repository._load_all()) == 7
  91. def test_delete_by(self, data_node):
  92. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  93. data_repository._save(data_node)
  94. task = Task("task_config_id", {}, print, [data_node], [data_node])
  95. # Create 5 entities with version 1.0 and 5 entities with version 2.0
  96. for i in range(10):
  97. task.id = TaskId(f"task-{i}")
  98. task._version = f"{(i+1) // 5}.0"
  99. task_repository._save(task)
  100. objs = task_repository._load_all()
  101. assert len(objs) == 10
  102. task_repository._delete_by("version", "1.0")
  103. assert len(task_repository._load_all()) == 5
  104. def test_search(self, data_node):
  105. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  106. data_repository._save(data_node)
  107. task = Task(
  108. "task_config_id",
  109. {},
  110. print,
  111. [data_node],
  112. [data_node],
  113. version="random_version_number",
  114. )
  115. for i in range(10):
  116. task.id = TaskId(f"task-{i}")
  117. task._owner_id = f"owner-{i}"
  118. task_repository._save(task)
  119. assert len(task_repository._load_all()) == 10
  120. objs = task_repository._search("owner_id", "owner-2")
  121. assert len(objs) == 1
  122. assert isinstance(objs[0], Task)
  123. objs = task_repository._search("owner_id", "owner-2", filters=[{"version": "random_version_number"}])
  124. assert len(objs) == 1
  125. assert isinstance(objs[0], Task)
  126. assert task_repository._search("owner_id", "owner-2", filters=[{"version": "non_existed_version"}]) == []
  127. def test_export(self, tmpdir, data_node):
  128. task_repository, data_repository = _TaskFSRepository(), _DataFSRepository()
  129. data_repository._save(data_node)
  130. task = Task("task_config_id", {}, print, [data_node], [data_node])
  131. task_repository._save(task)
  132. task_repository._export(task.id, tmpdir.strpath)
  133. dir_path = task_repository.dir_path
  134. assert os.path.exists(os.path.join(dir_path, f"{task.id}.json"))