test_job_repositories.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import os
  12. import pytest
  13. from taipy.core.data._data_sql_repository import _DataSQLRepository
  14. from taipy.core.exceptions import ModelNotFound
  15. from taipy.core.job._job_fs_repository import _JobFSRepository
  16. from taipy.core.job._job_sql_repository import _JobSQLRepository
  17. from taipy.core.job.job import Job, JobId
  18. from taipy.core.task._task_sql_repository import _TaskSQLRepository
  19. from taipy.core.task.task import Task
  20. class TestJobRepository:
  21. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  22. def test_save_and_load(self, data_node, job, repo, init_sql_repo):
  23. _DataSQLRepository()._save(data_node)
  24. task = Task("task_config_id", {}, print, [data_node], [data_node])
  25. _TaskSQLRepository()._save(task)
  26. job._task = task
  27. repository = repo()
  28. repository._save(job)
  29. obj = repository._load(job.id)
  30. assert isinstance(obj, Job)
  31. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  32. def test_exists(self, data_node, job, repo, init_sql_repo):
  33. _DataSQLRepository()._save(data_node)
  34. task = Task("task_config_id", {}, print, [data_node], [data_node])
  35. _TaskSQLRepository()._save(task)
  36. job._task = task
  37. repository = repo()
  38. repository._save(job)
  39. assert repository._exists(job.id)
  40. assert not repository._exists("not-existed-job")
  41. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  42. def test_load_all(self, data_node, job, repo, init_sql_repo):
  43. _DataSQLRepository()._save(data_node)
  44. task = Task("task_config_id", {}, print, [data_node], [data_node])
  45. _TaskSQLRepository()._save(task)
  46. job._task = task
  47. repository = repo()
  48. for i in range(10):
  49. job.id = JobId(f"job-{i}")
  50. repository._save(job)
  51. jobs = repository._load_all()
  52. assert len(jobs) == 10
  53. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  54. def test_load_all_with_filters(self, data_node, job, repo, init_sql_repo):
  55. repository = repo()
  56. _DataSQLRepository()._save(data_node)
  57. task = Task("task_config_id", {}, print, [data_node], [data_node])
  58. _TaskSQLRepository()._save(task)
  59. job._task = task
  60. for i in range(10):
  61. job.id = JobId(f"job-{i}")
  62. repository._save(job)
  63. objs = repository._load_all(filters=[{"id": "job-2"}])
  64. assert len(objs) == 1
  65. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  66. def test_delete(self, data_node, job, repo, init_sql_repo):
  67. repository = repo()
  68. _DataSQLRepository()._save(data_node)
  69. task = Task("task_config_id", {}, print, [data_node], [data_node])
  70. _TaskSQLRepository()._save(task)
  71. job._task = task
  72. repository._save(job)
  73. repository._delete(job.id)
  74. with pytest.raises(ModelNotFound):
  75. repository._load(job.id)
  76. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  77. def test_delete_all(self, data_node, job, repo, init_sql_repo):
  78. repository = repo()
  79. _DataSQLRepository()._save(data_node)
  80. task = Task("task_config_id", {}, print, [data_node], [data_node])
  81. _TaskSQLRepository()._save(task)
  82. job._task = task
  83. for i in range(10):
  84. job.id = JobId(f"job-{i}")
  85. repository._save(job)
  86. assert len(repository._load_all()) == 10
  87. repository._delete_all()
  88. assert len(repository._load_all()) == 0
  89. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  90. def test_delete_many(self, data_node, job, repo, init_sql_repo):
  91. repository = repo()
  92. _DataSQLRepository()._save(data_node)
  93. task = Task("task_config_id", {}, print, [data_node], [data_node])
  94. _TaskSQLRepository()._save(task)
  95. job._task = task
  96. for i in range(10):
  97. job.id = JobId(f"job-{i}")
  98. repository._save(job)
  99. objs = repository._load_all()
  100. assert len(objs) == 10
  101. ids = [x.id for x in objs[:3]]
  102. repository._delete_many(ids)
  103. assert len(repository._load_all()) == 7
  104. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  105. def test_delete_by(self, data_node, job, repo, init_sql_repo):
  106. repository = repo()
  107. _DataSQLRepository()._save(data_node)
  108. task = Task("task_config_id", {}, print, [data_node], [data_node])
  109. _TaskSQLRepository()._save(task)
  110. job._task = task
  111. # Create 5 entities with version 1.0 and 5 entities with version 2.0
  112. for i in range(10):
  113. job.id = JobId(f"job-{i}")
  114. job._version = f"{(i+1) // 5}.0"
  115. repository._save(job)
  116. objs = repository._load_all()
  117. assert len(objs) == 10
  118. repository._delete_by("version", "1.0")
  119. assert len(repository._load_all()) == 5
  120. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  121. def test_search(self, data_node, job, repo, init_sql_repo):
  122. repository = repo()
  123. _DataSQLRepository()._save(data_node)
  124. task = Task("task_config_id", {}, print, [data_node], [data_node])
  125. _TaskSQLRepository()._save(task)
  126. job._task = task
  127. for i in range(10):
  128. job.id = JobId(f"job-{i}")
  129. repository._save(job)
  130. assert len(repository._load_all()) == 10
  131. objs = repository._search("id", "job-2")
  132. assert len(objs) == 1
  133. assert isinstance(objs[0], Job)
  134. objs = repository._search("id", "job-2", filters=[{"version": "random_version_number"}])
  135. assert len(objs) == 1
  136. assert isinstance(objs[0], Job)
  137. assert repository._search("id", "job-2", filters=[{"version": "non_existed_version"}]) == []
  138. @pytest.mark.parametrize("repo", [_JobFSRepository, _JobSQLRepository])
  139. def test_export(self, tmpdir, job, repo, init_sql_repo):
  140. repository = repo()
  141. repository._save(job)
  142. repository._export(job.id, tmpdir.strpath)
  143. dir_path = repository.dir_path if repo == _JobFSRepository else os.path.join(tmpdir.strpath, "job")
  144. assert os.path.exists(os.path.join(dir_path, f"{job.id}.json"))