test_data_repositories.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
  11. import os
  12. import pytest
  13. from taipy.core.data._data_fs_repository import _DataFSRepository
  14. from taipy.core.data._data_sql_repository import _DataSQLRepository
  15. from taipy.core.data.data_node import DataNode, DataNodeId
  16. from taipy.core.exceptions import ModelNotFound
  class TestDataNodeRepository:
      """Contract tests shared by the filesystem and SQL data node repositories.

      Every test is parametrized over both repository classes, so the two storage
      back ends are held to the same behaviour. The ``data_node``, ``init_sql_repo``
      and ``tmpdir`` fixtures come from outside this module (presumably a
      ``conftest.py`` / pytest built-ins — confirm).
      """

      # Repository implementations that every test below is run against.
      _REPOSITORIES = [_DataFSRepository, _DataSQLRepository]

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_save_and_load(self, data_node: DataNode, repo, init_sql_repo):
          """A saved data node loads back with every persisted attribute intact."""
          repository = repo()
          repository._save(data_node)

          loaded_data_node = repository._load(data_node.id)

          assert isinstance(loaded_data_node, DataNode)
          assert data_node.id == loaded_data_node.id
          assert data_node._config_id == loaded_data_node._config_id
          assert data_node._owner_id == loaded_data_node._owner_id
          assert data_node._parent_ids == loaded_data_node._parent_ids
          assert data_node._scope == loaded_data_node._scope
          assert data_node._last_edit_date == loaded_data_node._last_edit_date
          assert data_node._edit_in_progress == loaded_data_node._edit_in_progress
          assert data_node._version == loaded_data_node._version
          assert data_node._validity_period == loaded_data_node._validity_period
          assert data_node._editor_id == loaded_data_node._editor_id
          assert data_node._editor_expiration_date == loaded_data_node._editor_expiration_date
          assert data_node._edits == loaded_data_node._edits
          assert data_node._properties == loaded_data_node._properties

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_exists(self, data_node, repo, init_sql_repo):
          """``_exists`` is true for a saved id and false for an unknown one."""
          repository = repo()
          repository._save(data_node)

          assert repository._exists(data_node.id)
          assert not repository._exists("not-existed-data-node")

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_load_all(self, data_node, repo, init_sql_repo):
          """Saving the fixture under ten distinct ids yields ten loaded entities."""
          repository = repo()
          # Re-save the same fixture object under a new id each time to create ten records.
          for i in range(10):
              data_node.id = DataNodeId(f"data_node-{i}")
              repository._save(data_node)

          data_nodes = repository._load_all()

          assert len(data_nodes) == 10

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_load_all_with_filters(self, data_node, repo, init_sql_repo):
          """``_load_all`` with an ``owner_id`` filter returns only the matching entity."""
          repository = repo()
          for i in range(10):
              data_node.id = DataNodeId(f"data_node-{i}")
              data_node._owner_id = f"task-{i}"
              repository._save(data_node)

          objs = repository._load_all(filters=[{"owner_id": "task-2"}])

          assert len(objs) == 1
          # Owner "task-2" was assigned together with id "data_node-2" in the loop above.
          assert objs[0].id == DataNodeId("data_node-2")

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_delete(self, data_node, repo, init_sql_repo):
          """Loading a deleted data node raises ``ModelNotFound``."""
          repository = repo()
          repository._save(data_node)

          repository._delete(data_node.id)

          with pytest.raises(ModelNotFound):
              repository._load(data_node.id)

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_delete_all(self, data_node, repo, init_sql_repo):
          """``_delete_all`` empties the repository."""
          repository = repo()
          for i in range(10):
              data_node.id = DataNodeId(f"data_node-{i}")
              repository._save(data_node)
          assert len(repository._load_all()) == 10

          repository._delete_all()

          assert len(repository._load_all()) == 0

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_delete_many(self, data_node, repo, init_sql_repo):
          """``_delete_many`` removes exactly the given ids and keeps the rest."""
          repository = repo()
          for i in range(10):
              data_node.id = DataNodeId(f"data_node-{i}")
              repository._save(data_node)
          objs = repository._load_all()
          assert len(objs) == 10

          ids = [obj.id for obj in objs[:3]]
          repository._delete_many(ids)

          assert len(repository._load_all()) == 7
          assert not any(repository._exists(deleted_id) for deleted_id in ids)

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_delete_by(self, data_node, repo, init_sql_repo):
          """``_delete_by("version", ...)`` removes only entities of that version."""
          repository = repo()
          # Create 5 entities with version 1.0 and 5 entities with version 2.0:
          # i in 0..4 -> "1.0", i in 5..9 -> "2.0".
          for i in range(10):
              data_node.id = DataNodeId(f"data_node-{i}")
              data_node._version = f"{i // 5 + 1}.0"
              repository._save(data_node)
          objs = repository._load_all()
          assert len(objs) == 10

          repository._delete_by("version", "1.0")

          remaining = repository._load_all()
          assert len(remaining) == 5
          assert all(dn._version == "2.0" for dn in remaining)

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_search(self, data_node, repo, init_sql_repo):
          """``_search`` matches on an attribute and honours optional version filters."""
          repository = repo()
          for i in range(10):
              data_node.id = DataNodeId(f"data_node-{i}")
              data_node._owner_id = f"task-{i}"
              repository._save(data_node)
          assert len(repository._load_all()) == 10

          objs = repository._search("owner_id", "task-2")
          assert len(objs) == 1
          assert isinstance(objs[0], DataNode)
          assert objs[0]._owner_id == "task-2"

          # NOTE(review): relies on the `data_node` fixture's version being
          # "random_version_number" — confirm against the fixture definition.
          objs = repository._search("owner_id", "task-2", filters=[{"version": "random_version_number"}])
          assert len(objs) == 1
          assert isinstance(objs[0], DataNode)

          assert repository._search("owner_id", "task-2", filters=[{"version": "non_existed_version"}]) == []

      @pytest.mark.parametrize("repo", _REPOSITORIES)
      def test_export(self, tmpdir, data_node, repo, init_sql_repo):
          """``_export`` produces a ``<id>.json`` file for the saved data node."""
          repository = repo()
          repository._save(data_node)

          repository._export(data_node.id, tmpdir.strpath)

          # NOTE(review): the filesystem branch checks `repository.dir_path` rather than a
          # path under `tmpdir`. If that is the folder `_save` already writes to, this branch
          # does not observe the exported copy at all — confirm the intended export location.
          if repo is _DataFSRepository:
              dir_path = repository.dir_path
          else:
              dir_path = os.path.join(tmpdir.strpath, "data_node")
          assert os.path.exists(os.path.join(dir_path, f"{data_node.id}.json"))