test_scenario_repositories.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import os
  12. import pytest
  13. from taipy.core.exceptions import ModelNotFound
  14. from taipy.core.scenario._scenario_fs_repository import _ScenarioFSRepository
  15. from taipy.core.scenario._scenario_sql_repository import _ScenarioSQLRepository
  16. from taipy.core.scenario.scenario import Scenario, ScenarioId
  17. class TestScenarioFSRepository:
  18. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  19. def test_save_and_load(self, scenario: Scenario, repo, init_sql_repo):
  20. repository = repo()
  21. repository._save(scenario)
  22. loaded_scenario = repository._load(scenario.id)
  23. assert isinstance(loaded_scenario, Scenario)
  24. assert scenario._config_id == loaded_scenario._config_id
  25. assert scenario.id == loaded_scenario.id
  26. assert scenario._tasks == loaded_scenario._tasks
  27. assert scenario._additional_data_nodes == loaded_scenario._additional_data_nodes
  28. assert scenario._creation_date == loaded_scenario._creation_date
  29. assert scenario._cycle == loaded_scenario._cycle
  30. assert scenario._primary_scenario == loaded_scenario._primary_scenario
  31. assert scenario._tags == loaded_scenario._tags
  32. assert scenario._properties == loaded_scenario._properties
  33. assert scenario._sequences == loaded_scenario._sequences
  34. assert scenario._version == loaded_scenario._version
  35. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  36. def test_exists(self, scenario, repo, init_sql_repo):
  37. repository = repo()
  38. repository._save(scenario)
  39. assert repository._exists(scenario.id)
  40. assert not repository._exists("not-existed-scenario")
  41. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  42. def test_load_all(self, scenario, repo, init_sql_repo):
  43. repository = repo()
  44. for i in range(10):
  45. scenario.id = ScenarioId(f"scenario-{i}")
  46. repository._save(scenario)
  47. data_nodes = repository._load_all()
  48. assert len(data_nodes) == 10
  49. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  50. def test_load_all_with_filters(self, scenario, repo, init_sql_repo):
  51. repository = repo()
  52. for i in range(10):
  53. scenario.id = ScenarioId(f"scenario-{i}")
  54. repository._save(scenario)
  55. objs = repository._load_all(filters=[{"id": "scenario-2"}])
  56. assert len(objs) == 1
  57. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  58. def test_delete(self, scenario, repo, init_sql_repo):
  59. repository = repo()
  60. repository._save(scenario)
  61. repository._delete(scenario.id)
  62. with pytest.raises(ModelNotFound):
  63. repository._load(scenario.id)
  64. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  65. def test_delete_all(self, scenario, repo, init_sql_repo):
  66. repository = repo()
  67. for i in range(10):
  68. scenario.id = ScenarioId(f"scenario-{i}")
  69. repository._save(scenario)
  70. assert len(repository._load_all()) == 10
  71. repository._delete_all()
  72. assert len(repository._load_all()) == 0
  73. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  74. def test_delete_many(self, scenario, repo, init_sql_repo):
  75. repository = repo()
  76. for i in range(10):
  77. scenario.id = ScenarioId(f"scenario-{i}")
  78. repository._save(scenario)
  79. objs = repository._load_all()
  80. assert len(objs) == 10
  81. ids = [x.id for x in objs[:3]]
  82. repository._delete_many(ids)
  83. assert len(repository._load_all()) == 7
  84. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  85. def test_delete_by(self, scenario, repo, init_sql_repo):
  86. repository = repo()
  87. # Create 5 entities with version 1.0 and 5 entities with version 2.0
  88. for i in range(10):
  89. scenario.id = ScenarioId(f"scenario-{i}")
  90. scenario._version = f"{(i+1) // 5}.0"
  91. repository._save(scenario)
  92. objs = repository._load_all()
  93. assert len(objs) == 10
  94. repository._delete_by("version", "1.0")
  95. assert len(repository._load_all()) == 5
  96. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  97. def test_search(self, scenario, repo, init_sql_repo):
  98. repository = repo()
  99. for i in range(10):
  100. scenario.id = ScenarioId(f"scenario-{i}")
  101. repository._save(scenario)
  102. assert len(repository._load_all()) == 10
  103. objs = repository._search("id", "scenario-2")
  104. assert len(objs) == 1
  105. assert isinstance(objs[0], Scenario)
  106. objs = repository._search("id", "scenario-2", filters=[{"version": "random_version_number"}])
  107. assert len(objs) == 1
  108. assert isinstance(objs[0], Scenario)
  109. assert repository._search("id", "scenario-2", filters=[{"version": "non_existed_version"}]) == []
  110. @pytest.mark.parametrize("repo", [_ScenarioFSRepository, _ScenarioSQLRepository])
  111. def test_export(self, tmpdir, scenario, repo, init_sql_repo):
  112. repository = repo()
  113. repository._save(scenario)
  114. repository._export(scenario.id, tmpdir.strpath)
  115. dir_path = repository.dir_path if repo == _ScenarioFSRepository else os.path.join(tmpdir.strpath, "scenario")
  116. assert os.path.exists(os.path.join(dir_path, f"{scenario.id}.json"))